#![allow(unused_unsafe)]
use super::super::common;
use crate::core::dictionary::Dictionary;
/// Letter case used for the hex digits with values `10..=15` when encoding.
///
/// Decoding in this module accepts both cases regardless of the variant passed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HexVariant {
    /// Digits `A`-`F`.
    Uppercase,
    /// Digits `a`-`f`.
    Lowercase,
}
pub fn encode(data: &[u8], _dictionary: &Dictionary, variant: HexVariant) -> Option<String> {
let output_len = data.len() * 2;
let mut result = String::with_capacity(output_len);
#[cfg(target_arch = "aarch64")]
unsafe {
encode_neon_impl(data, variant, &mut result);
}
Some(result)
}
pub fn decode(encoded: &str, _variant: HexVariant) -> Option<Vec<u8>> {
let encoded_bytes = encoded.as_bytes();
if !encoded_bytes.len().is_multiple_of(2) {
return None;
}
let output_len = encoded_bytes.len() / 2;
let mut result = Vec::with_capacity(output_len);
#[cfg(target_arch = "aarch64")]
unsafe {
if !decode_neon_impl(encoded_bytes, &mut result) {
return None;
}
}
Some(result)
}
/// NEON hex encoder: each iteration turns 16 input bytes into 32 ASCII characters.
///
/// Inputs shorter than one block, and any tail after the last full block, are handed to
/// `encode_scalar_remainder`.
///
/// # Safety
/// The caller must ensure the NEON feature is available on the running CPU.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn encode_neon_impl(data: &[u8], variant: HexVariant, result: &mut String) {
    use std::arch::aarch64::*;
    const BLOCK_SIZE: usize = 16;

    if data.len() < BLOCK_SIZE {
        encode_scalar_remainder(data, variant, result);
        return;
    }

    let (num_rounds, simd_bytes) = common::calculate_blocks(data.len(), BLOCK_SIZE);

    let alphabet: &[u8; 16] = match variant {
        HexVariant::Uppercase => b"0123456789ABCDEF",
        HexVariant::Lowercase => b"0123456789abcdef",
    };
    // SAFETY: `alphabet` is exactly 16 bytes, the width of one q-register load.
    let lut = unsafe { vld1q_u8(alphabet.as_ptr()) };
    let low_nibble_mask = unsafe { vdupq_n_u8(0x0F) };

    for round in 0..num_rounds {
        let start = round * BLOCK_SIZE;
        let mut chunk_hex = [0u8; 2 * BLOCK_SIZE];
        // SAFETY: assumes `calculate_blocks` only yields rounds whose 16-byte window
        // starting at `start` lies inside `data`; `chunk_hex` has room for both stores.
        unsafe {
            let bytes = vld1q_u8(data.as_ptr().add(start));
            let high = vqtbl1q_u8(lut, vandq_u8(vshrq_n_u8(bytes, 4), low_nibble_mask));
            let low = vqtbl1q_u8(lut, vandq_u8(bytes, low_nibble_mask));
            // Interleave so each byte's high-nibble digit precedes its low-nibble digit.
            vst1q_u8(chunk_hex.as_mut_ptr(), vzip1q_u8(high, low));
            vst1q_u8(chunk_hex.as_mut_ptr().add(BLOCK_SIZE), vzip2q_u8(high, low));
        }
        result.extend(chunk_hex.iter().map(|&b| char::from(b)));
    }

    if simd_bytes < data.len() {
        encode_scalar_remainder(&data[simd_bytes..], variant, result);
    }
}
/// Appends the hex encoding of `data` to `result`: the high nibble first, then the low one.
///
/// The digit alphabet follows `variant`'s letter case.
fn encode_scalar_remainder(data: &[u8], variant: HexVariant, result: &mut String) {
    let alphabet: &[u8; 16] = match variant {
        HexVariant::Lowercase => b"0123456789abcdef",
        HexVariant::Uppercase => b"0123456789ABCDEF",
    };
    result.extend(data.iter().flat_map(|&byte| {
        [
            alphabet[usize::from(byte >> 4)],
            alphabet[usize::from(byte & 0x0F)],
        ]
        .map(char::from)
    }));
}
/// NEON hex decoder: each 32-character block becomes 16 bytes appended to `result`.
///
/// Returns `false` in either of two cases:
/// - a block contains a lane that `decode_nibble_chars_neon` maps to a value `>= 16`
///   (its marker for a non-hex character);
/// - the scalar-decoded tail fails.
///
/// Bytes from blocks decoded before the failure remain in `result`.
///
/// # Safety
/// The caller must ensure the NEON feature is available on the running CPU.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn decode_neon_impl(encoded: &[u8], result: &mut Vec<u8>) -> bool {
    use std::arch::aarch64::*;
    const INPUT_BLOCK_SIZE: usize = 32;
    const OUTPUT_BLOCK_SIZE: usize = 16;

    let (num_rounds, simd_bytes) = common::calculate_blocks(encoded.len(), INPUT_BLOCK_SIZE);

    let mut cursor = 0;
    for _ in 0..num_rounds {
        let mut packed_bytes = [0u8; OUTPUT_BLOCK_SIZE];
        // SAFETY: assumes `calculate_blocks` only yields rounds whose 32-byte window
        // starting at `cursor` lies inside `encoded`; `packed_bytes` holds one 16-byte store.
        let block_ok = unsafe {
            let src = encoded.as_ptr().add(cursor);
            let first_half = vld1q_u8(src);
            let second_half = vld1q_u8(src.add(16));
            // De-interleave: even positions hold high-nibble chars, odd positions low-nibble.
            let high = decode_nibble_chars_neon(vuzp1q_u8(first_half, second_half));
            let low = decode_nibble_chars_neon(vuzp2q_u8(first_half, second_half));
            let valid = vmaxvq_u8(high) < 16 && vmaxvq_u8(low) < 16;
            if valid {
                vst1q_u8(packed_bytes.as_mut_ptr(), vorrq_u8(vshlq_n_u8(high, 4), low));
            }
            valid
        };
        if !block_ok {
            return false;
        }
        result.extend_from_slice(&packed_bytes);
        cursor += INPUT_BLOCK_SIZE;
    }

    if simd_bytes >= encoded.len() {
        return true;
    }
    decode_scalar_remainder(&encoded[simd_bytes..], result)
}
/// Maps 16 ASCII hex characters to their nibble values, lane by lane.
///
/// Lanes holding `0`-`9`, `A`-`F` or `a`-`f` receive their value (`0..=15`). Every other
/// lane receives `0xFF`, so a horizontal max `>= 16` signals invalid input.
///
/// The original range checks had inclusive upper bounds one past the last valid character
/// (`0x3A`, `0x47`, `0x67`). As a result `':'` was decoded as 10 and silently accepted.
/// The bounds are now the exact class limits.
///
/// # Safety
/// The caller must ensure the NEON feature is available on the running CPU.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn decode_nibble_chars_neon(
    chars: std::arch::aarch64::uint8x16_t,
) -> std::arch::aarch64::uint8x16_t {
    use std::arch::aarch64::*;
    unsafe {
        // Inclusive class tests: lo <= c && c <= hi.
        let is_digit = vandq_u8(
            vcgeq_u8(chars, vdupq_n_u8(b'0')),
            vcleq_u8(chars, vdupq_n_u8(b'9')),
        );
        let is_upper = vandq_u8(
            vcgeq_u8(chars, vdupq_n_u8(b'A')),
            vcleq_u8(chars, vdupq_n_u8(b'F')),
        );
        let is_lower = vandq_u8(
            vcgeq_u8(chars, vdupq_n_u8(b'a')),
            vcleq_u8(chars, vdupq_n_u8(b'f')),
        );
        // Per-class offsets map '0' -> 0, 'A' -> 10, 'a' -> 10.
        let digit_vals = vandq_u8(is_digit, vsubq_u8(chars, vdupq_n_u8(b'0')));
        let upper_vals = vandq_u8(is_upper, vsubq_u8(chars, vdupq_n_u8(b'A' - 10)));
        let lower_vals = vandq_u8(is_lower, vsubq_u8(chars, vdupq_n_u8(b'a' - 10)));
        let valid_vals = vorrq_u8(vorrq_u8(digit_vals, upper_vals), lower_vals);
        let is_valid = vorrq_u8(vorrq_u8(is_digit, is_upper), is_lower);
        // Keep decoded values in valid lanes; force 0xFF everywhere else.
        vorrq_u8(
            vandq_u8(is_valid, valid_vals),
            vbicq_u8(vdupq_n_u8(0xFF), is_valid),
        )
    }
}
/// Decodes pairs of hex characters from `data`, appending one byte per pair to `result`.
///
/// Returns `false` for odd-length input, before anything is appended. It also returns
/// `false` at the first pair containing a non-hex character, in which case bytes from the
/// preceding pairs stay in `result`.
fn decode_scalar_remainder(data: &[u8], result: &mut Vec<u8>) -> bool {
    let pairs = data.chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return false;
    }
    for pair in pairs {
        let (Some(high), Some(low)) = (decode_hex_char(pair[0]), decode_hex_char(pair[1])) else {
            return false;
        };
        result.push((high << 4) | low);
    }
    true
}
/// Returns the value (`0..=15`) of a single ASCII hex digit in either letter case, or
/// `None` for any other byte.
fn decode_hex_char(c: u8) -> Option<u8> {
    // `char::from` maps the byte into U+0000..=U+00FF. Within that range, `to_digit(16)`
    // accepts exactly `0-9`, `a-f` and `A-F`.
    char::from(c)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
/// Reports whether `dict` is the standard hexadecimal alphabet and, if so, in which case.
///
/// Returns `Some` only when `dict` has base 16 and its digits `0..=15` are exactly
/// `0123456789ABCDEF` (uppercase) or `0123456789abcdef` (lowercase). All 16 digits are
/// compared. Checking only digit 10 would accept permuted base-16 alphabets (e.g.
/// `9876543210ABCDEF`), for which this module's fixed-alphabet encoder produces the wrong
/// output.
pub fn identify_hex_variant(dict: &Dictionary) -> Option<HexVariant> {
    if dict.base() != 16 {
        return None;
    }
    // Digit 10 decides which alphabet to verify against.
    let variant = match dict.encode_digit(10)? {
        'A' => HexVariant::Uppercase,
        'a' => HexVariant::Lowercase,
        _ => return None,
    };
    let alphabet = match variant {
        HexVariant::Uppercase => "0123456789ABCDEF",
        HexVariant::Lowercase => "0123456789abcdef",
    };
    (0..16)
        .zip(alphabet.chars())
        .all(|(digit, expected)| dict.encode_digit(digit) == Some(expected))
        .then_some(variant)
}
#[cfg(test)]
// NOTE(review): presumably silences a deprecation warning from
// `Dictionary::new_with_mode`; confirm, and narrow the allow if possible.
#[allow(deprecated)]
mod tests {
    use super::*;
    use crate::core::config::EncodingMode;
    use crate::core::dictionary::Dictionary;

    const HELLO: &[u8] = b"Hello, World!";
    const HELLO_UPPER: &str = "48656C6C6F2C20576F726C6421";
    const HELLO_LOWER: &str = "48656c6c6f2c20576f726c6421";

    /// Builds a chunked dictionary whose digits are the characters of `alphabet`, in order.
    fn make_dict(alphabet: &str) -> Dictionary {
        let chars: Vec<char> = alphabet.chars().collect();
        Dictionary::new_with_mode(chars, EncodingMode::Chunked, None).unwrap()
    }

    fn make_hex_dict_upper() -> Dictionary {
        make_dict("0123456789ABCDEF")
    }

    fn make_hex_dict_lower() -> Dictionary {
        make_dict("0123456789abcdef")
    }

    // Every test asserts on the `Option` itself (via `expect` or `Some(..)`), so a `None`
    // result fails the test instead of silently skipping the assertion.

    #[test]
    fn test_encode_uppercase() {
        let dictionary = make_hex_dict_upper();
        let result = encode(HELLO, &dictionary, HexVariant::Uppercase).expect("encode failed");
        assert_eq!(result, HELLO_UPPER);
    }

    #[test]
    fn test_encode_lowercase() {
        let dictionary = make_hex_dict_lower();
        let result = encode(HELLO, &dictionary, HexVariant::Lowercase).expect("encode failed");
        assert_eq!(result, HELLO_LOWER);
    }

    #[test]
    fn test_decode_uppercase() {
        assert_eq!(
            decode(HELLO_UPPER, HexVariant::Uppercase),
            Some(HELLO.to_vec())
        );
    }

    #[test]
    fn test_decode_lowercase() {
        assert_eq!(
            decode(HELLO_LOWER, HexVariant::Lowercase),
            Some(HELLO.to_vec())
        );
    }

    #[test]
    fn test_decode_mixed_case() {
        let encoded = "48656C6c6F2c20576F726C6421";
        assert_eq!(decode(encoded, HexVariant::Uppercase), Some(HELLO.to_vec()));

        // Long enough to exercise whole SIMD blocks as well as the scalar tail.
        let long_encoded = encoded.repeat(3);
        assert_eq!(
            decode(&long_encoded, HexVariant::Uppercase),
            Some(HELLO.repeat(3))
        );
    }

    #[test]
    fn test_round_trip() {
        let cases = [
            (make_hex_dict_upper(), HexVariant::Uppercase),
            (make_hex_dict_lower(), HexVariant::Lowercase),
        ];
        for (dictionary, variant) in &cases {
            for len in 0..100 {
                let original: Vec<u8> = (0..len).map(|i| (i * 7) as u8).collect();
                let encoded = encode(&original, dictionary, *variant).expect("encode failed");
                assert_eq!(encoded.len(), original.len() * 2);
                let decoded = decode(&encoded, *variant).expect("decode failed");
                assert_eq!(
                    decoded, original,
                    "Round-trip failed at length {len} ({variant:?})"
                );
            }
        }
    }

    #[test]
    fn test_decode_invalid_chars() {
        // 64-character inputs place the bad character past the first 32-character block.
        let long_with_z = format!("{}Z{}", "0".repeat(40), "0".repeat(23));
        let long_with_g = format!("{}G", "0".repeat(63));
        let invalid_cases = [
            "4865ZZ",
            "48656",
            "48656G6C",
            long_with_z.as_str(),
            long_with_g.as_str(),
        ];
        for encoded in invalid_cases {
            assert_eq!(
                decode(encoded, HexVariant::Uppercase),
                None,
                "Should reject: {}",
                encoded
            );
        }
    }

    #[test]
    fn test_identify_variant() {
        let upper_dict = make_hex_dict_upper();
        assert_eq!(
            identify_hex_variant(&upper_dict),
            Some(HexVariant::Uppercase)
        );
        let lower_dict = make_hex_dict_lower();
        assert_eq!(
            identify_hex_variant(&lower_dict),
            Some(HexVariant::Lowercase)
        );
    }

    #[test]
    fn test_encode_edge_cases() {
        let dictionary = make_hex_dict_upper();
        let cases: [(&[u8], &str); 3] = [
            (&[], ""),
            (&[0xFF], "FF"),
            (&[0x00, 0x00, 0x00], "000000"),
        ];
        for (input, expected) in cases {
            assert_eq!(
                encode(input, &dictionary, HexVariant::Uppercase).as_deref(),
                Some(expected)
            );
        }
    }
}