/// Outcome of a speculative decoding attempt on a string.
///
/// Produced by `speculative_decode`: when an encoding is recognised and the
/// payload decodes cleanly, `decoded` holds the bytes and `encoding` names the
/// scheme; otherwise `decoded` is `None` and `encoding` is `EncodingType::None`.
#[derive(Debug, Clone)]
pub struct DecodeResult {
/// Decoded payload, or `None` when the input was not decoded.
pub decoded: Option<Vec<u8>>,
/// Encoding that was used to produce `decoded`.
pub encoding: EncodingType,
}
/// Encoding scheme identified for an input string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncodingType {
/// The input was not recognised as (or failed to decode as) a supported encoding.
None,
/// The input decoded as base64.
Base64,
/// The input decoded as hexadecimal.
Hex,
}
/// Tries to interpret `s` as an encoded payload and decode it.
///
/// Hexadecimal is attempted first (it requires an even length), then base64.
/// Inputs shorter than eight bytes are never decoded. The first attempt that
/// succeeds wins; when neither applies the result carries `decoded: None` and
/// `EncodingType::None`.
#[inline]
pub fn speculative_decode(s: &str) -> DecodeResult {
    /// Shortest input that is considered worth decoding.
    const MIN_ENCODED_LEN: usize = 8;

    let bytes = s.as_bytes();
    let long_enough = bytes.len() >= MIN_ENCODED_LEN;

    // Hex has priority: a string made only of hex digits is also valid base64.
    let from_hex = if long_enough && bytes.len() % 2 == 0 && is_likely_hex(bytes) {
        try_decode_hex(bytes)
    } else {
        None
    };
    if let Some(decoded) = from_hex {
        return DecodeResult {
            decoded: Some(decoded),
            encoding: EncodingType::Hex,
        };
    }

    let from_base64 = if long_enough && is_likely_base64(bytes) {
        try_decode_base64(bytes)
    } else {
        None
    };
    match from_base64 {
        Some(decoded) => DecodeResult {
            decoded: Some(decoded),
            encoding: EncodingType::Base64,
        },
        None => DecodeResult {
            decoded: None,
            encoding: EncodingType::None,
        },
    }
}
/// Cheap pre-filter: reports whether `crate::simd::is_all_hex` accepts `bytes`.
///
/// `bytes` must be valid UTF-8; it is reinterpreted as `&str` without checking.
#[inline]
fn is_likely_hex(bytes: &[u8]) -> bool {
    // SAFETY: the caller visible in this file (`speculative_decode`) passes
    // `str::as_bytes()`, which is valid UTF-8 by construction.
    let text = unsafe { std::str::from_utf8_unchecked(bytes) };
    crate::simd::is_all_hex(text)
}
/// Cheap pre-filter: reports whether `bytes` is at least four bytes long and
/// accepted by `crate::simd::is_all_base64_chars`.
///
/// `bytes` must be valid UTF-8; it is reinterpreted as `&str` without checking.
#[inline]
fn is_likely_base64(bytes: &[u8]) -> bool {
    /// Anything shorter than one base64 quantum is rejected outright.
    const MIN_LEN: usize = 4;

    if bytes.len() >= MIN_LEN {
        // SAFETY: the caller visible in this file (`speculative_decode`) passes
        // `str::as_bytes()`, which is valid UTF-8 by construction.
        let text = unsafe { std::str::from_utf8_unchecked(bytes) };
        crate::simd::is_all_base64_chars(text)
    } else {
        false
    }
}
/// Decodes `bytes` as hexadecimal using the widest implementation available.
///
/// On x86_64 the AVX2 routine is used for inputs of at least 32 bytes, on
/// aarch64 the NEON routine for inputs of at least 16 bytes; both require the
/// CPU feature to be detected at run time. Everything else goes through
/// `decode_hex_scalar`. Returns `None` when the chosen routine rejects the input.
#[inline]
fn try_decode_hex(bytes: &[u8]) -> Option<Vec<u8>> {
    #[cfg(target_arch = "x86_64")]
    {
        let fills_a_block = bytes.len() >= 32;
        if fills_a_block && is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was verified at run time just above.
            return unsafe { decode_hex_avx2(bytes) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        let fills_a_block = bytes.len() >= 16;
        if fills_a_block && std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: NEON availability was verified at run time just above.
            return unsafe { decode_hex_neon(bytes) };
        }
    }

    decode_hex_scalar(bytes)
}
/// Decodes a hexadecimal byte string one digit pair at a time.
///
/// Each pair of ASCII hex digits (either case) becomes one output byte, the
/// first digit supplying the high nibble.
///
/// Returns `None` when `bytes` has an odd length or contains a byte that is
/// not a hex digit. An empty input decodes to an empty vector.
#[inline]
fn decode_hex_scalar(bytes: &[u8]) -> Option<Vec<u8>> {
    // `chunks_exact(2)` would silently ignore a dangling final nibble, turning
    // malformed input into a truncated "successful" decode; reject it instead.
    if bytes.len() % 2 != 0 {
        return None;
    }

    let mut result = Vec::with_capacity(bytes.len() / 2);
    for pair in bytes.chunks_exact(2) {
        let high = hex_digit_value(pair[0])?;
        let low = hex_digit_value(pair[1])?;
        result.push((high << 4) | low);
    }
    Some(result)
}
/// Returns the numeric value (0–15) of an ASCII hexadecimal digit.
///
/// Accepts `0-9`, `a-f` and `A-F`; any other byte yields `None`.
#[inline]
fn hex_digit_value(b: u8) -> Option<u8> {
    // `to_digit(16)` recognises exactly the ASCII hex digits in either case;
    // bytes >= 0x80 map to Latin-1 characters, none of which are digits.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn decode_hex_avx2(bytes: &[u8]) -> Option<Vec<u8>> {
use std::arch::x86_64::*;
let mut result = Vec::with_capacity(bytes.len() / 2);
let chunks = bytes.len() / 32;
let ptr = bytes.as_ptr();
let digit_lut = _mm256_setr_epi8(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, );
let alpha_adjust = _mm256_set1_epi8(10); let nine = _mm256_set1_epi8(b'9' as i8);
let lower_a = _mm256_set1_epi8(b'a' as i8);
let upper_a = _mm256_set1_epi8(b'A' as i8);
let mask_0f = _mm256_set1_epi8(0x0F);
for i in 0..chunks {
let chunk = _mm256_loadu_si256(ptr.add(i * 32) as *const __m256i);
let is_digit = _mm256_cmpgt_epi8(_mm256_set1_epi8(b':' as i8), chunk); let is_digit = _mm256_and_si256(
is_digit,
_mm256_cmpgt_epi8(chunk, _mm256_set1_epi8(b'0' as i8 - 1)),
);
let nibble = _mm256_and_si256(chunk, mask_0f);
let is_lower = _mm256_cmpgt_epi8(chunk, _mm256_set1_epi8(b'`' as i8)); let is_upper = _mm256_and_si256(
_mm256_cmpgt_epi8(chunk, _mm256_set1_epi8(b'@' as i8)), _mm256_cmpgt_epi8(_mm256_set1_epi8(b'G' as i8), chunk), );
let alpha_value = _mm256_add_epi8(nibble, _mm256_set1_epi8(9));
let is_alpha = _mm256_or_si256(is_lower, is_upper);
let values = _mm256_blendv_epi8(alpha_value, nibble, is_digit);
let raw: [u8; 32] = std::mem::transmute(values);
for j in 0..16 {
let high = raw[j * 2];
let low = raw[j * 2 + 1];
if high > 15 || low > 15 {
return decode_hex_scalar(bytes);
}
result.push((high << 4) | low);
}
}
let remainder = &bytes[chunks * 32..];
if !remainder.is_empty() {
if let Some(mut rest) = decode_hex_scalar(remainder) {
result.append(&mut rest);
} else {
return None;
}
}
Some(result)
}
/// Decodes a hexadecimal byte string, classifying 16 input bytes per step with NEON.
///
/// Every 16-byte block is validated with vector comparisons (each byte must
/// be `0-9`, `a-f` or `A-F`) and converted to nibble values in one pass; the
/// nibble pairs are then packed into output bytes. Input left over after the
/// last full block is handed to `decode_hex_scalar`.
///
/// Returns `None` when the length is odd or any byte is not a hex digit.
///
/// # Safety
///
/// The caller must ensure the running CPU supports NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn decode_hex_neon(bytes: &[u8]) -> Option<Vec<u8>> {
    use std::arch::aarch64::*;

    // Blocks are an even size, so an odd total would leave a dangling nibble.
    if bytes.len() % 2 != 0 {
        return None;
    }

    let mut result = Vec::with_capacity(bytes.len() / 2);

    let digit_lo = vdupq_n_u8(b'0');
    let digit_hi = vdupq_n_u8(b'9');
    let alpha_lo = vdupq_n_u8(b'a');
    let alpha_hi = vdupq_n_u8(b'f');
    let case_bit = vdupq_n_u8(0x20);
    let mask_0f = vdupq_n_u8(0x0F);
    let alpha_offset = vdupq_n_u8(9);

    let mut blocks = bytes.chunks_exact(16);
    for block in &mut blocks {
        // SAFETY: `block` is exactly 16 bytes long.
        let chunk = vld1q_u8(block.as_ptr());

        let is_digit = vandq_u8(vcgeq_u8(chunk, digit_lo), vcleq_u8(chunk, digit_hi));
        // Setting bit 5 folds 'A'..='F' onto 'a'..='f'; no other byte is
        // mapped into that range by the OR.
        let lowered = vorrq_u8(chunk, case_bit);
        let is_alpha = vandq_u8(vcgeq_u8(lowered, alpha_lo), vcleq_u8(lowered, alpha_hi));

        // Every lane must be a digit or a hex letter. Checking the nibble
        // value alone is not enough: e.g. 'p' (0x70) would yield 0 + 9 = 9.
        // Valid lanes are 0xFF, so the minimum is 0xFF only if all are valid.
        if vminvq_u8(vorrq_u8(is_digit, is_alpha)) != 0xFF {
            return None;
        }

        // Digits: low nibble is the value. Letters: low nibble is 1..=6, add 9.
        let nibble = vandq_u8(chunk, mask_0f);
        let values = vbslq_u8(is_digit, nibble, vaddq_u8(nibble, alpha_offset));

        // SAFETY: `uint8x16_t` and `[u8; 16]` have the same size and any bit
        // pattern is a valid byte array.
        let raw: [u8; 16] = std::mem::transmute(values);
        result.extend(raw.chunks_exact(2).map(|pair| (pair[0] << 4) | pair[1]));
    }

    let tail = decode_hex_scalar(blocks.remainder())?;
    result.extend_from_slice(&tail);
    Some(result)
}
/// Decodes `bytes` as base64 using the widest implementation available.
///
/// On x86_64 the AVX2 routine is used for inputs of at least 32 bytes, on
/// aarch64 the NEON routine for inputs of at least 16 bytes; both require the
/// CPU feature to be detected at run time. Everything else goes through
/// `decode_base64_scalar`. Returns `None` when the chosen routine rejects the input.
#[inline]
fn try_decode_base64(bytes: &[u8]) -> Option<Vec<u8>> {
    #[cfg(target_arch = "x86_64")]
    {
        let fills_a_block = bytes.len() >= 32;
        if fills_a_block && is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was verified at run time just above.
            return unsafe { decode_base64_avx2(bytes) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        let fills_a_block = bytes.len() >= 16;
        if fills_a_block && std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: NEON availability was verified at run time just above.
            return unsafe { decode_base64_neon(bytes) };
        }
    }

    decode_base64_scalar(bytes)
}
/// Decodes base64 text, accepting both the standard (`+`, `/`) and URL-safe
/// (`-`, `_`) alphabets.
///
/// Up to two trailing `=` padding characters are stripped; padding itself is
/// optional. A `=` anywhere else is treated like any other invalid symbol.
/// Unused low bits in a final partial group are ignored.
///
/// Returns `None` when the input contains a byte outside the alphabets or
/// when the unpadded length leaves a single dangling symbol (which cannot
/// encode a whole byte). An empty input decodes to an empty vector.
fn decode_base64_scalar(bytes: &[u8]) -> Option<Vec<u8>> {
    /// Symbol -> 6-bit value; -1 marks bytes that are not part of either alphabet.
    const DECODE_TABLE: [i8; 256] = {
        let mut table = [-1i8; 256];
        let mut i = 0u8;
        while i < 26 {
            table[(b'A' + i) as usize] = i as i8;
            table[(b'a' + i) as usize] = (i + 26) as i8;
            i += 1;
        }
        let mut i = 0u8;
        while i < 10 {
            table[(b'0' + i) as usize] = (i + 52) as i8;
            i += 1;
        }
        table[b'+' as usize] = 62;
        table[b'/' as usize] = 63;
        // URL-safe variants share the values of their standard counterparts.
        table[b'-' as usize] = 62;
        table[b'_' as usize] = 63;
        table
    };

    let sextet = |symbol: u8| -> Option<u32> {
        let value = DECODE_TABLE[usize::from(symbol)];
        if value < 0 {
            None
        } else {
            Some(value as u32)
        }
    };

    // Padding is only meaningful at the very end and is at most two characters.
    // Any other '=' stays in `input` and is rejected by the table lookup.
    let padding = bytes
        .iter()
        .rev()
        .take(2)
        .take_while(|&&b| b == b'=')
        .count();
    let input = &bytes[..bytes.len() - padding];

    let mut result = Vec::with_capacity(input.len() * 3 / 4);

    // Full groups: four symbols -> 24 bits -> three bytes.
    let mut quads = input.chunks_exact(4);
    for quad in &mut quads {
        let mut accum = 0u32;
        for &symbol in quad {
            accum = (accum << 6) | sextet(symbol)?;
        }
        result.extend_from_slice(&[(accum >> 16) as u8, (accum >> 8) as u8, accum as u8]);
    }

    // Final partial group: two symbols carry one byte, three carry two.
    let tail = quads.remainder();
    if !tail.is_empty() {
        let mut accum = 0u32;
        for &symbol in tail {
            accum = (accum << 6) | sextet(symbol)?;
        }
        // Left-align the bits as if the group had been padded out to four symbols.
        accum <<= 6 * (4 - tail.len());
        match tail.len() {
            3 => result.extend_from_slice(&[(accum >> 16) as u8, (accum >> 8) as u8]),
            2 => result.push((accum >> 16) as u8),
            _ => return None,
        }
    }

    Some(result)
}
/// Decodes base64 text in 32-symbol blocks on AVX2-capable CPUs.
///
/// Both the standard (`+`, `/`) and URL-safe (`-`, `_`) alphabets are
/// accepted. Up to two trailing `=` padding characters are stripped; a `=`
/// anywhere else makes the input invalid. Symbols are translated byte by
/// byte (no explicit AVX2 intrinsics are used); each block of 32 symbols
/// yields 24 output bytes. Input shorter than one block, and whatever is
/// left after the last full block, is handed to `decode_base64_scalar`.
///
/// Returns `None` when the input contains an invalid symbol or the trailing
/// partial group cannot be decoded.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2, because the function
/// is compiled with that target feature enabled.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn decode_base64_avx2(bytes: &[u8]) -> Option<Vec<u8>> {
    // Padding is only meaningful at the very end and is at most two characters.
    let padding = bytes
        .iter()
        .rev()
        .take(2)
        .take_while(|&&b| b == b'=')
        .count();
    let input = &bytes[..bytes.len() - padding];
    // Reject interior '=' here so the verdict does not depend on how the
    // scalar routine treats padding in the tail it is given below.
    if input.contains(&b'=') {
        return None;
    }

    if input.len() < 32 {
        return decode_base64_scalar(input);
    }

    let mut result = Vec::with_capacity(input.len() * 3 / 4);
    let mut blocks = input.chunks_exact(32);
    for block in &mut blocks {
        // A block is eight groups of four symbols; each group is 24 bits.
        for quad in block.chunks_exact(4) {
            let mut accum = 0u32;
            for &byte in quad {
                let sextet = match byte {
                    b'A'..=b'Z' => byte - b'A',
                    b'a'..=b'z' => byte - b'a' + 26,
                    b'0'..=b'9' => byte - b'0' + 52,
                    b'+' | b'-' => 62,
                    b'/' | b'_' => 63,
                    _ => return None,
                };
                accum = (accum << 6) | u32::from(sextet);
            }
            result.extend_from_slice(&[(accum >> 16) as u8, (accum >> 8) as u8, accum as u8]);
        }
    }

    let tail = decode_base64_scalar(blocks.remainder())?;
    result.extend_from_slice(&tail);
    Some(result)
}
/// Decodes base64 text in 16-symbol blocks on NEON-capable CPUs.
///
/// Both the standard (`+`, `/`) and URL-safe (`-`, `_`) alphabets are
/// accepted. Up to two trailing `=` padding characters are stripped; a `=`
/// anywhere else makes the input invalid. Symbols are translated byte by
/// byte (no explicit NEON intrinsics are used); each block of 16 symbols
/// yields 12 output bytes. Input shorter than one block, and whatever is
/// left after the last full block, is handed to `decode_base64_scalar`.
///
/// Returns `None` when the input contains an invalid symbol or the trailing
/// partial group cannot be decoded.
///
/// # Safety
///
/// The caller must ensure the running CPU supports NEON, because the function
/// is compiled with that target feature enabled.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn decode_base64_neon(bytes: &[u8]) -> Option<Vec<u8>> {
    // Padding is only meaningful at the very end and is at most two characters.
    let padding = bytes
        .iter()
        .rev()
        .take(2)
        .take_while(|&&b| b == b'=')
        .count();
    let input = &bytes[..bytes.len() - padding];
    // Reject interior '=' here so the verdict does not depend on how the
    // scalar routine treats padding in the tail it is given below.
    if input.contains(&b'=') {
        return None;
    }

    if input.len() < 16 {
        return decode_base64_scalar(input);
    }

    let mut result = Vec::with_capacity(input.len() * 3 / 4);
    let mut blocks = input.chunks_exact(16);
    for block in &mut blocks {
        // A block is four groups of four symbols; each group is 24 bits.
        for quad in block.chunks_exact(4) {
            let mut accum = 0u32;
            for &byte in quad {
                let sextet = match byte {
                    b'A'..=b'Z' => byte - b'A',
                    b'a'..=b'z' => byte - b'a' + 26,
                    b'0'..=b'9' => byte - b'0' + 52,
                    b'+' | b'-' => 62,
                    b'/' | b'_' => 63,
                    _ => return None,
                };
                accum = (accum << 6) | u32::from(sextet);
            }
            result.extend_from_slice(&[(accum >> 16) as u8, (accum >> 8) as u8, accum as u8]);
        }
    }

    let tail = decode_base64_scalar(blocks.remainder())?;
    result.extend_from_slice(&tail);
    Some(result)
}
/// Computes the entropy of `s`, looking through a hex or base64 wrapper when present.
///
/// The string is first passed to `speculative_decode`. If that yields a
/// non-empty payload, the entropy (as computed by `super::calculate_entropy`)
/// of the decoded bytes is returned together with the detected encoding.
/// Otherwise the entropy of the raw string bytes is returned with
/// `EncodingType::None`.
#[inline]
pub fn entropy_with_decode(s: &str) -> (f64, EncodingType) {
    let attempt = speculative_decode(s);
    let detected = attempt.encoding;

    if let Some(payload) = attempt.decoded {
        // An empty payload carries no information; measure the raw text instead.
        if !payload.is_empty() {
            return (super::calculate_entropy(&payload), detected);
        }
    }

    (super::calculate_entropy(s.as_bytes()), EncodingType::None)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// "48656c6c6f" is the hex spelling of "Hello".
    #[test]
    fn test_hex_decode() {
        let outcome = speculative_decode("48656c6c6f");
        assert_eq!(outcome.encoding, EncodingType::Hex);
        assert_eq!(outcome.decoded, Some(b"Hello".to_vec()));
    }

    /// "SGVsbG8gV29ybGQ=" is the padded base64 spelling of "Hello World".
    #[test]
    fn test_base64_decode() {
        let outcome = speculative_decode("SGVsbG8gV29ybGQ=");
        assert_eq!(outcome.encoding, EncodingType::Base64);
        assert_eq!(outcome.decoded, Some(b"Hello World".to_vec()));
    }

    /// Ordinary text must come back untouched.
    #[test]
    fn test_not_encoded() {
        let outcome = speculative_decode("just_a_normal_string");
        assert_eq!(outcome.encoding, EncodingType::None);
        assert_eq!(outcome.decoded, None);
    }

    /// Base64 of ten 'A' bytes: the decoded payload is far more regular than its encoding.
    #[test]
    fn test_entropy_with_decode() {
        let encoded = "QUFBQUFBQUFBQQ==";
        let (entropy, encoding) = entropy_with_decode(encoded);
        assert_eq!(encoding, EncodingType::Base64);
        assert!(entropy < 1.0, "Decoded entropy should be low: {}", entropy);

        let raw_entropy = super::super::calculate_entropy(encoded.as_bytes());
        assert!(
            raw_entropy > entropy,
            "Raw entropy {} should be higher than decoded {}",
            raw_entropy,
            entropy
        );
    }

    /// Hex of four 'A' bytes decodes to a single repeated byte, i.e. zero entropy.
    #[test]
    fn test_hex_entropy_reduction() {
        let (entropy, encoding) = entropy_with_decode("41414141");
        assert_eq!(encoding, EncodingType::Hex);
        assert_eq!(entropy, 0.0);
    }

    /// 72 hex digits: long enough to reach the block-wise decoders where available.
    #[test]
    fn test_long_hex_string() {
        let long_hex = "48656c6c6f20576f726c6421".repeat(3);
        let outcome = speculative_decode(&long_hex);
        assert_eq!(outcome.encoding, EncodingType::Hex);
        assert!(outcome.decoded.is_some());
    }
}