fashex 0.0.3

Conversion from bytes to hexadecimal string.
Documentation
//! Optimized implementations based on unstable [`core::simd`] APIs.

#![allow(unsafe_code, reason = "XXX")]
#![allow(unsafe_op_in_unsafe_fn, reason = "XXX")]

use core::mem::MaybeUninit;
use core::simd::prelude::*;
use core::slice;

use crate::backend::generic::{
    decode_generic_unchecked, encode_generic_unchecked, validate_generic,
};
use crate::error::InvalidInput;
use crate::util::digits16;

/// Encodes `src` as ASCII hexadecimal digits into `dst`, 16 input bytes per
/// SIMD step.
///
/// Every input byte produces one `[hi, lo]` pair of digits taken from the
/// table returned by `digits16::<UPPER>()`. Once fewer than 16 bytes remain,
/// the tail is handed to [`encode_generic_unchecked`].
///
/// ## Safety
///
/// We assume that:
///
/// 1. `src.len() == dst.len()`.
pub(crate) unsafe fn encode_simd128_unchecked<const UPPER: bool>(
    mut src: &[u8],
    mut dst: &mut [[MaybeUninit<u8>; 2]],
) {
    /// Process 16 bytes of input, and produce 16 * 2 bytes of output.
    const BATCH: usize = size_of::<Simd<u8, 16>>();

    if src.len() >= BATCH {
        // Selects the low nibble of every lane.
        let m = Simd::splat(0b_0000_1111);
        // 16-entry digit table, indexed by nibble value.
        let lut = Simd::from_array(*digits16::<UPPER>());

        while src.len() >= BATCH {
            // let [byte @ u8; 16]
            let chunk = Simd::<u8, 16>::from_slice(src);

            // let [hi; 16] = [lut[byte >> 4]; 16];
            let hi = lut.swizzle_dyn(chunk >> 4);
            // let [lo; 16] = [lut[byte & 0b_0000_1111]; 16];
            let lo = lut.swizzle_dyn(chunk & m);

            // Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]).
            let (out0, out1) = Simd::<u8, 16>::interleave(hi, lo);

            // Store the result through raw pointers: `dst` may be
            // uninitialized, so no `&mut [u8]` is formed over it.
            //
            // SAFETY: `dst.len() == src.len() >= BATCH` (precondition 1), so
            // `dst` spans at least `2 * BATCH` writable bytes, which is
            // exactly what the two 16-byte stores cover. `write_unaligned`
            // places no alignment requirement on the destination.
            unsafe {
                let out = dst.as_mut_ptr().cast::<Simd<u8, 16>>();
                out.write_unaligned(out0);
                out.add(1).write_unaligned(out1);
            }

            src = &src[BATCH..];
            // SAFETY: `dst.len() >= BATCH`, see above.
            dst = unsafe { dst.get_unchecked_mut(BATCH..) };
        }
    }

    // SAFETY: both slices were advanced by the same amount, so they still
    // have equal lengths.
    unsafe { encode_generic_unchecked::<UPPER>(src, dst) };
}

/// Checks that every byte of `src` is an ASCII hexadecimal digit (`0-9`,
/// `a-f` or `A-F`), looking at 8 pairs (16 bytes) per SIMD step.
///
/// Returns `Err(InvalidInput)` as soon as a full batch contains a byte outside
/// those ranges. The tail of fewer than 8 pairs is checked by
/// [`validate_generic`], whose result is returned as-is.
///
/// ## Safety
///
/// This function only reads bytes inside `src`; no further caller-side
/// requirement is visible here.
/// NOTE(review): confirm whether `unsafe` is kept only for uniformity with the
/// other backend entry points.
pub(crate) unsafe fn validate_simd128(mut src: &[[u8; 2]]) -> Result<(), InvalidInput> {
    /// Pairs handled per step: one 16-byte vector holds 8 `[u8; 2]` pairs.
    const BATCH: usize = size_of::<Simd<u8, 16>>() / 2;

    // Loop-invariant broadcast constants.
    let zero = Simd::<u8, 16>::splat(b'0');
    let digit_span = Simd::<u8, 16>::splat(9 + 1);
    let lower_a = Simd::<u8, 16>::splat(b'a');
    let letter_span = Simd::<u8, 16>::splat(5 + 1);
    let to_lower = Simd::<u8, 16>::splat(0b_0010_0000);

    while src.len() >= BATCH {
        let (head, tail) = src.split_at(BATCH);
        let bytes = Simd::<u8, 16>::from_slice(head.as_flattened());

        // '0'..='9'  <=>  byte.wrapping_sub('0') < 10
        let digits = (bytes - zero).simd_lt(digit_span);

        // Setting bit 5 folds 'A'..='F' onto 'a'..='f', then
        // 'a'..='f'  <=>  folded.wrapping_sub('a') < 6
        let letters = ((bytes | to_lower) - lower_a).simd_lt(letter_span);

        // Every lane has to fall into one of the two ranges.
        if !(digits | letters).all() {
            return Err(InvalidInput);
        }

        src = tail;
    }

    validate_generic(src)
}

/// Decodes pairs of ASCII hexadecimal digits from `src` into bytes in `dst`,
/// 16 pairs (32 input bytes) per SIMD step.
///
/// Each `[hi, lo]` pair yields one output byte `(unhex(hi) << 4) | unhex(lo)`.
/// Once fewer than 16 pairs remain, the tail is handed to
/// [`decode_generic_unchecked`].
///
/// ## Safety
///
/// We assume that:
///
/// 1. `src.len() == dst.len()`.
/// 2. The input is a valid hexadecimal string.
pub(crate) unsafe fn decode_simd128_unchecked(
    mut src: &[[u8; 2]],
    mut dst: &mut [MaybeUninit<u8>],
) {
    /// Process 16 * 2 bytes of input, and produce 16 bytes of output.
    const BATCH: usize = size_of::<Simd<u8, 16>>();

    /// Maps each lane from an ASCII hex digit to its 4-bit value:
    /// `(value >> 6) * 9 + (value & 0b_0000_1111)`.
    ///
    /// * `'0'..='9'` (`0x30..=0x39`): `value >> 6 == 0`, and the low nibble
    ///   already is the digit value.
    /// * `'A'..='F'` (`0x41..=0x46`) and `'a'..='f'` (`0x61..=0x66`):
    ///   `value >> 6 == 1`, and the low nibble is `1..=6`, so adding 9 gives
    ///   `10..=15`.
    ///
    /// Only meaningful for valid hex digits (precondition 2 of the caller).
    #[inline]
    fn unhex(value: Simd<u8, 16>) -> Simd<u8, 16> {
        // 0 for decimal digits, 1 for letters.
        let sr6 = value >> 6;

        // value & 0b_0000_1111, take 4 least significant bits
        let and15 = value & Simd::splat(0b_0000_1111);

        // sr6 * 9 = sr6 * 8 + sr6 = (sr6 << 3) + sr6
        let mul = (sr6 << 3) + sr6;

        // (value >> 6) * 9 + (value & 0b_0000_1111)
        mul + and15
    }

    /// `(hi << 4) | lo`
    #[inline]
    fn nib2byte(hi: Simd<u8, 16>, lo: Simd<u8, 16>) -> Simd<u8, 16> {
        (hi << 4) | lo
    }

    while src.len() >= BATCH {
        let bytes = src.as_ptr().cast::<u8>();

        // let [[hi,lo]; 0-7] = chunk0
        // let [[hi,lo]; 8-15] = chunk1
        //
        // SAFETY: the loop condition guarantees at least `BATCH` pairs, i.e.
        // `2 * BATCH` initialized bytes starting at `bytes`.
        let (chunk0, chunk1) = unsafe {
            (
                Simd::<u8, 16>::from_slice(slice::from_raw_parts(bytes, BATCH)),
                Simd::<u8, 16>::from_slice(slice::from_raw_parts(bytes.add(BATCH), BATCH)),
            )
        };

        // De-interleave: even byte positions are the high digits, ...
        let hi = simd_swizzle!(
            chunk0,
            chunk1,
            [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]
        );
        // ... odd byte positions are the low digits.
        let lo = simd_swizzle!(
            chunk0,
            chunk1,
            [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31]
        );

        let hi = unhex(hi);
        let lo = unhex(lo);

        let out = nib2byte(hi, lo);

        // Store the result through a raw pointer: `dst` may be uninitialized,
        // so no `&mut [u8]` is formed over it.
        //
        // SAFETY: `dst.len() == src.len() >= BATCH` (precondition 1), so `dst`
        // has room for one 16-byte store. `write_unaligned` places no
        // alignment requirement on the destination.
        unsafe { dst.as_mut_ptr().cast::<Simd<u8, 16>>().write_unaligned(out) };

        src = &src[BATCH..];
        // SAFETY: `dst.len() >= BATCH`, see above.
        dst = unsafe { dst.get_unchecked_mut(BATCH..) };
    }

    // SAFETY: both slices were advanced by the same amount, so they still have
    // equal lengths, and the remaining input is still valid hex.
    let _ = unsafe { decode_generic_unchecked::<true>(src, dst) };
}

/// Smoke tests for the SIMD backend: encode → validate → decode round trips,
/// plus an exhaustive per-byte check of the validator.
#[cfg(test)]
mod smoking {
    use alloc::string::String;
    use alloc::vec;
    use core::mem::MaybeUninit;
    use core::{slice, str};

    use super::{decode_simd128_unchecked, encode_simd128_unchecked, validate_simd128};
    use crate::util::{DIGITS_LOWER_16, DIGITS_UPPER_16};

    // Round-trip driver.
    //
    // For the byte slice given as `Case`:
    // 1. builds the expected lower- and upper-case hex strings by scalar table
    //    lookup,
    // 2. encodes with `Encode::<false>` / `Encode::<true>` and compares against
    //    the expected strings,
    // 3. runs every `Validate` function over both encodings and expects `Ok`,
    // 4. runs every `Decode` function over both encodings and expects the
    //    original input back.
    macro_rules! test {
        (
            Encode = $encode_f:ident;
            Validate = $($validate_f:ident),*;
            Decode = $($decode_f:ident),*;
            Case = $i:expr
        ) => {{
            let input = $i;

            // Scalar reference encoding: high nibble first, then low nibble.
            let expected_lower = input
                .iter()
                .flat_map(|b| [
                    DIGITS_LOWER_16[(*b >> 4) as usize] as char,
                    DIGITS_LOWER_16[(*b & 0b1111) as usize] as char,
                ])
                .collect::<String>();
            let expected_upper = input
                .iter()
                .flat_map(|b| [
                    DIGITS_UPPER_16[(*b >> 4) as usize] as char,
                    DIGITS_UPPER_16[(*b & 0b1111) as usize] as char,
                ])
                .collect::<String>();

            // One uninitialized `[hi, lo]` slot per input byte.
            let mut output_lower = vec![[MaybeUninit::<u8>::uninit(); 2]; input.len()];
            let mut output_upper = vec![[MaybeUninit::<u8>::uninit(); 2]; input.len()];

            unsafe {
                $encode_f::<false>(input, &mut output_lower);
                $encode_f::<true>(input, &mut output_upper);
            }

            // Reinterpret the buffers as initialized `[u8; 2]` pairs; this
            // relies on the encoder having written every slot.
            let output_lower = unsafe {
                slice::from_raw_parts(
                    output_lower.as_ptr().cast::<[u8; 2]>(),
                    output_lower.len(),
                )
            };
            let output_upper = unsafe {
                slice::from_raw_parts(
                    output_upper.as_ptr().cast::<[u8; 2]>(),
                    output_upper.len(),
                )
            };

            assert_eq!(
                output_lower.as_flattened(),
                expected_lower.as_bytes(),
                "Encode error, expect \"{expected_lower}\", got \"{}\" ({:?})",
                str::from_utf8(output_lower.as_flattened()).unwrap_or("<invalid utf-8>"),
                output_lower.as_flattened()
            );
            assert_eq!(
                output_upper.as_flattened(),
                expected_upper.as_bytes(),
                "Encode error, expect \"{expected_upper}\", got \"{}\" ({:?})",
                str::from_utf8(output_upper.as_flattened()).unwrap_or("<invalid utf-8>"),
                output_upper.as_flattened()
            );

            // Freshly encoded output must be accepted by every validator.
            $(
                unsafe {
                    $validate_f(output_lower)
                        .unwrap_or_else(|_| panic!("validation failed for {}", stringify!($validate_f)));
                    $validate_f(output_upper)
                        .unwrap_or_else(|_| panic!("validation failed for {}", stringify!($validate_f)));
                }
            )*

            // Decoding either case must give back the original bytes.
            $({
                let mut decoded_lower = vec![MaybeUninit::<u8>::uninit(); input.len()];
                let mut decoded_upper = vec![MaybeUninit::<u8>::uninit(); input.len()];

                unsafe {
                    $decode_f(output_lower, &mut decoded_lower);
                    $decode_f(output_upper, &mut decoded_upper);

                    assert_eq!(
                        decoded_lower.assume_init_ref(),
                        input,
                        "Decode error for {}, expect {:?}, got {:?}",
                        stringify!($decode_f),
                        input,
                        decoded_lower.assume_init_ref()
                    );
                    assert_eq!(
                        decoded_upper.assume_init_ref(),
                        input,
                        "Decode error for {}, expect {:?}, got {:?}",
                        stringify!($decode_f),
                        input,
                        decoded_upper.assume_init_ref()
                    );
                }
            })*
        }};
    }

    /// Round-trips inputs of 0, 15, 16 and 17 bytes, i.e. empty input and the
    /// lengths just below, at and just above the 16-byte batch used by the
    /// encoder and decoder.
    ///
    /// Ignored under Miri. NOTE(review): presumably because of the SIMD
    /// operations involved; confirm.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_simd128() {
        const CASE: &[u8; 17] = &[
            0x12, 0x77, 0x4C, 0x16, 0x16, 0x2B, 0x99, 0x97, 0x37, 0x62, 0x24, 0x24, 0x36, 0x83,
            0xA4, 0xF1, 0xDD,
        ];

        test! {
            Encode = encode_simd128_unchecked;
            Validate = validate_simd128;
            Decode = decode_simd128_unchecked;
            Case = &[]
        }

        test! {
            Encode = encode_simd128_unchecked;
            Validate = validate_simd128;
            Decode = decode_simd128_unchecked;
            Case = &CASE[..15]
        }

        test! {
            Encode = encode_simd128_unchecked;
            Validate = validate_simd128;
            Decode = decode_simd128_unchecked;
            Case = &CASE[..16]
        }

        test! {
            Encode = encode_simd128_unchecked;
            Validate = validate_simd128;
            Decode = decode_simd128_unchecked;
            Case = &CASE[..17]
        };
    }

    /// For several lengths `l` (in pairs) and every byte value `c`, builds
    /// `2 * l` bytes of `b'a'`, overwrites the byte at index `l` with `c`, and
    /// checks that `validate_simd128` accepts the buffer exactly when `c` is
    /// an ASCII hex digit.
    ///
    /// Ignored under Miri. NOTE(review): presumably because of the SIMD
    /// operations involved; confirm.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_validation() {
        for l in [15, 16, 17, 31, 32, 33, 63, 64, 65] {
            for c in 0..=255 {
                let mut bytes = vec![b'a'; l * 2];

                // Plant the byte under test in the middle of the buffer.
                bytes[l] = c;

                // View the even-length byte buffer as `[u8; 2]` pairs.
                let bytes = unsafe { bytes.as_chunks_unchecked() };

                if c.is_ascii_hexdigit() {
                    unsafe {
                        assert!(
                            validate_simd128(bytes).is_ok(),
                            "simd128 validation failed for char `{}` (l={l})",
                            c as char
                        );
                    }
                } else {
                    unsafe {
                        assert!(
                            validate_simd128(bytes).is_err(),
                            "simd128 validation should have failed for byte {c} (l={l})"
                        );
                    }
                }
            }
        }
    }
}