use super::super::common;
use crate::core::dictionary::Dictionary;
use crate::simd::variants::Base32Variant;
pub fn encode(data: &[u8], dictionary: &Dictionary, variant: Base32Variant) -> Option<String> {
let output_len = data.len().div_ceil(5) * 8;
let mut result = String::with_capacity(output_len);
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
unsafe {
encode_avx2_impl(data, dictionary, variant, &mut result);
}
} else {
unsafe {
encode_ssse3_impl(data, dictionary, variant, &mut result);
}
}
}
#[cfg(not(target_arch = "x86_64"))]
{
encode_scalar_remainder(data, dictionary, &mut result);
}
Some(result)
}
/// Checks the trailing `=` padding of a base32 string and strips it.
///
/// Returns the input without its padding when the layout is acceptable:
///
/// * With no padding at all, the input is returned unchanged as long as its
///   length modulo 8 is one of 0, 2, 4, 5 or 7.
/// * With padding, the total length must be a multiple of 8 and the number of
///   `=` characters must be exactly what the payload length requires
///   (2 -> 6, 4 -> 4, 5 -> 3, 7 -> 1).
///
/// Returns `None` for every other combination. Only the padding layout is
/// checked here; the payload characters themselves are not inspected.
fn validate_base32_padding(input: &str) -> Option<&str> {
    let pad_len = input.bytes().rev().take_while(|&b| b == b'=').count();
    let payload_len = input.len() - pad_len;

    // A payload remainder of 1, 3 or 6 characters can never encode whole bytes.
    let required_pad = match payload_len % 8 {
        0 => 0,
        2 => 6,
        4 => 4,
        5 => 3,
        7 => 1,
        _ => return None,
    };

    if pad_len == 0 {
        return Some(input);
    }

    if input.len().is_multiple_of(8) && pad_len == required_pad {
        // Everything after `payload_len` is ASCII '=', so this is a char boundary.
        Some(&input[..payload_len])
    } else {
        None
    }
}
pub fn decode(encoded: &str, variant: Base32Variant) -> Option<Vec<u8>> {
let input_no_padding = validate_base32_padding(encoded)?;
let encoded_bytes = input_no_padding.as_bytes();
let output_len = (input_no_padding.len() / 8) * 5
+ match input_no_padding.len() % 8 {
0 => 0,
2 => 1,
4 => 2,
5 => 3,
7 => 4,
_ => return None, };
let mut result = Vec::with_capacity(output_len);
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
if !unsafe { decode_avx2_impl(encoded_bytes, variant, &mut result) } {
return None;
}
} else if !unsafe { decode_ssse3_impl(encoded_bytes, variant, &mut result) } {
return None;
}
}
#[cfg(not(target_arch = "x86_64"))]
{
if !unsafe { decode_ssse3_impl(encoded_bytes, variant, &mut result) } {
return None;
}
}
Some(result)
}
/// Encodes `data` with the scalar chunked encoder and pads the output.
///
/// After the scalar encoder has appended its characters, the dictionary's
/// padding character (if it defines one) is appended until `result.len()` is
/// a multiple of 8. The padding amount is derived from the total length of
/// `result`, so any text already present in it counts towards the group.
fn encode_scalar_remainder(data: &[u8], dictionary: &Dictionary, result: &mut String) {
    common::encode_scalar_chunked(data, dictionary, result);

    let missing = (8 - (result.len() % 8)) % 8;
    let Some(pad_char) = dictionary.padding() else {
        // Dictionaries without a padding character produce unpadded output.
        return;
    };
    result.extend(std::iter::repeat(pad_char).take(missing));
}
/// AVX2 base32 encoder: turns 20 input bytes into 32 characters per round.
///
/// Inputs shorter than 32 bytes are handed to the SSSE3 encoder. For longer
/// inputs, each round loads two overlapping 16-byte windows (at `offset` and
/// `offset + 10`), of which only the first 10 bytes of each are consumed.
/// Whatever the rounds do not cover is passed to the SSSE3 encoder, which
/// also takes care of the scalar tail and padding.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2. This function also calls the
/// SSSE3 routines, so SSSE3 must be available as well.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn encode_avx2_impl(
    data: &[u8],
    dictionary: &Dictionary,
    variant: Base32Variant,
    result: &mut String,
) {
    use std::arch::x86_64::*;

    const BLOCK_SIZE: usize = 20;

    if data.len() < 32 {
        // SAFETY: forwarded under this function's own safety contract.
        unsafe { encode_ssse3_impl(data, dictionary, variant, result) };
        return;
    }

    // Hold back 12 bytes from the SIMD rounds so the 16-byte loads have
    // headroom at the end of the slice.
    let safe_len = data.len().saturating_sub(12);
    let (num_rounds, simd_bytes) = common::calculate_blocks(safe_len, BLOCK_SIZE);

    for round in 0..num_rounds {
        let offset = round * BLOCK_SIZE;

        // SAFETY: relies on `calculate_blocks` keeping `num_rounds * BLOCK_SIZE`
        // within `safe_len` (its definition is not visible here), which leaves
        // both 16-byte windows inside `data`.
        let input_lo = unsafe { _mm_loadu_si128(data.as_ptr().add(offset) as *const __m128i) };
        let input_hi =
            unsafe { _mm_loadu_si128(data.as_ptr().add(offset + 10) as *const __m128i) };

        let block = _mm256_set_m128i(input_hi, input_lo);
        // SAFETY: both helpers require AVX2, which this function's contract provides.
        let indices = unsafe { extract_5bit_indices_avx2(block) };
        let encoded = unsafe { translate_encode_avx2(indices, variant) };

        let mut ascii = [0u8; 32];
        // SAFETY: `ascii` is exactly 32 bytes and the store is unaligned.
        unsafe {
            _mm256_storeu_si256(ascii.as_mut_ptr() as *mut __m256i, encoded);
        }
        result.extend(ascii.iter().map(|&byte| char::from(byte)));
    }

    if simd_bytes < data.len() {
        // SAFETY: forwarded under this function's own safety contract.
        unsafe {
            encode_ssse3_impl(&data[simd_bytes..], dictionary, variant, result);
        }
    }
}
/// Unpacks each 128-bit lane of `input` into sixteen 5-bit indices.
///
/// Both lanes are processed independently by `unpack_5bit_simple`, which
/// reads the first 10 bytes of a lane, and the two results are recombined in
/// the same lane order.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 and SSSE3.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn extract_5bit_indices_avx2(
    input: std::arch::x86_64::__m256i,
) -> std::arch::x86_64::__m256i {
    use std::arch::x86_64::*;

    // SAFETY: forwarded under this function's own safety contract.
    let lower = unsafe { unpack_5bit_simple(_mm256_castsi256_si128(input)) };
    let upper = unsafe { unpack_5bit_simple(_mm256_extracti128_si256(input, 1)) };

    _mm256_set_m128i(upper, lower)
}
/// Maps 32 five-bit indices to their ASCII characters for `variant`.
///
/// Both alphabets consist of two contiguous ASCII runs. Every index is first
/// offset by the first character of the initial run; indices past the end of
/// that run then receive an extra adjustment that moves them into the second:
///
/// * `Rfc4648`: `A`..`Z` for 0..=25, then `2`..`7` (adjustment -41).
/// * `Rfc4648Hex`: `0`..`9` for 0..=9, then `A`..`V` (adjustment +7).
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn translate_encode_avx2(
    indices: std::arch::x86_64::__m256i,
    variant: Base32Variant,
) -> std::arch::x86_64::__m256i {
    use std::arch::x86_64::*;

    // (last index of the first run, first character, extra shift for the second run)
    let (first_run_end, first_char, second_run_shift): (i8, i8, i8) = match variant {
        Base32Variant::Rfc4648 => (25, b'A' as i8, -41),
        Base32Variant::Rfc4648Hex => (9, b'0' as i8, 7),
    };

    let in_second_run = _mm256_cmpgt_epi8(indices, _mm256_set1_epi8(first_run_end));
    let rebased = _mm256_add_epi8(indices, _mm256_set1_epi8(first_char));
    let extra = _mm256_and_si256(in_second_run, _mm256_set1_epi8(second_run_shift));

    _mm256_add_epi8(rebased, extra)
}
/// SSSE3 base32 encoder: turns 10 input bytes into 16 characters per round.
///
/// Inputs shorter than 16 bytes go straight to the scalar encoder. For longer
/// inputs, each round loads a 16-byte window of which only the first 10 bytes
/// are consumed. Whatever the rounds do not cover is finished, and padded,
/// by the scalar encoder.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSSE3.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn encode_ssse3_impl(
    data: &[u8],
    dictionary: &Dictionary,
    variant: Base32Variant,
    result: &mut String,
) {
    use std::arch::x86_64::*;

    const BLOCK_SIZE: usize = 10;

    if data.len() < 16 {
        encode_scalar_remainder(data, dictionary, result);
        return;
    }

    // Hold back 6 bytes from the SIMD rounds so the 16-byte load of the last
    // round has headroom at the end of the slice.
    let safe_len = data.len().saturating_sub(6);
    let (num_rounds, simd_bytes) = common::calculate_blocks(safe_len, BLOCK_SIZE);

    for round in 0..num_rounds {
        let offset = round * BLOCK_SIZE;

        // SAFETY: relies on `calculate_blocks` keeping `num_rounds * BLOCK_SIZE`
        // within `safe_len` (its definition is not visible here), which leaves
        // the 16-byte window inside `data`.
        let window = unsafe { _mm_loadu_si128(data.as_ptr().add(offset) as *const __m128i) };
        // SAFETY: both helpers require SSSE3, which this function's contract provides.
        let indices = unsafe { extract_5bit_indices(window) };
        let encoded = unsafe { translate_encode(indices, variant) };

        let mut ascii = [0u8; 16];
        // SAFETY: `ascii` is exactly 16 bytes and the store is unaligned.
        unsafe {
            _mm_storeu_si128(ascii.as_mut_ptr() as *mut __m128i, encoded);
        }
        result.extend(ascii.iter().map(|&byte| char::from(byte)));
    }

    if simd_bytes < data.len() {
        encode_scalar_remainder(&data[simd_bytes..], dictionary, result);
    }
}
/// Unpacks `input` into sixteen 5-bit indices.
///
/// Thin wrapper around `unpack_5bit_simple`, which reads only the first 10
/// bytes of the vector.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSSE3.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn extract_5bit_indices(input: std::arch::x86_64::__m128i) -> std::arch::x86_64::__m128i {
// SAFETY: forwarded under this function's own safety contract.
unsafe { unpack_5bit_simple(input) }
}
/// Splits the first 10 bytes of `input` into sixteen 5-bit indices.
///
/// The bytes are treated as two independent 5-byte groups. Each group is read
/// as a 40-bit big-endian value and cut into eight 5-bit fields, most
/// significant field first. Bytes 10..16 of `input` are ignored. Every byte
/// of the result is in `0..=31`.
///
/// The unpacking itself is done in scalar code through stack buffers; only
/// the load and store use vector instructions.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSSE3.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn unpack_5bit_simple(input: std::arch::x86_64::__m128i) -> std::arch::x86_64::__m128i {
    use std::arch::x86_64::*;

    let mut raw = [0u8; 16];
    // SAFETY: `raw` is exactly 16 bytes and the store is unaligned.
    unsafe {
        _mm_storeu_si128(raw.as_mut_ptr() as *mut __m128i, input);
    }

    let mut indices = [0u8; 16];
    for (group, out) in raw[..10].chunks_exact(5).zip(indices.chunks_exact_mut(8)) {
        // Assemble the 40-bit group, first byte most significant.
        let bits = group.iter().fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));
        for (slot, index) in out.iter_mut().enumerate() {
            // Field 0 occupies bits 39..=35, field 7 occupies bits 4..=0.
            *index = ((bits >> (35 - 5 * slot)) & 0x1F) as u8;
        }
    }

    // SAFETY: `indices` is exactly 16 bytes and the load is unaligned.
    unsafe { _mm_loadu_si128(indices.as_ptr() as *const __m128i) }
}
/// Maps 16 five-bit indices to their ASCII characters for `variant`.
///
/// Both alphabets consist of two contiguous ASCII runs. Every index is first
/// offset by the first character of the initial run; indices past the end of
/// that run then receive an extra adjustment that moves them into the second:
///
/// * `Rfc4648`: `A`..`Z` for 0..=25, then `2`..`7` (adjustment -41).
/// * `Rfc4648Hex`: `0`..`9` for 0..=9, then `A`..`V` (adjustment +7).
///
/// # Safety
///
/// The caller must ensure the CPU supports SSSE3.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn translate_encode(
    indices: std::arch::x86_64::__m128i,
    variant: Base32Variant,
) -> std::arch::x86_64::__m128i {
    use std::arch::x86_64::*;

    // (last index of the first run, first character, extra shift for the second run)
    let (first_run_end, first_char, second_run_shift): (i8, i8, i8) = match variant {
        Base32Variant::Rfc4648 => (25, b'A' as i8, -41),
        Base32Variant::Rfc4648Hex => (9, b'0' as i8, 7),
    };

    let in_second_run = _mm_cmpgt_epi8(indices, _mm_set1_epi8(first_run_end));
    let rebased = _mm_add_epi8(indices, _mm_set1_epi8(first_char));
    let extra = _mm_and_si128(in_second_run, _mm_set1_epi8(second_run_shift));

    _mm_add_epi8(rebased, extra)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn decode_avx2_impl(encoded: &[u8], variant: Base32Variant, result: &mut Vec<u8>) -> bool {
use std::arch::x86_64::*;
const INPUT_BLOCK_SIZE: usize = 32;
if encoded.len() < 32 {
return unsafe { decode_ssse3_impl(encoded, variant, result) };
}
let (delta_check_128, delta_rebase_128) = unsafe { get_decode_delta_tables(variant) };
let delta_check = _mm256_broadcastsi128_si256(delta_check_128);
let delta_rebase = _mm256_broadcastsi128_si256(delta_rebase_128);
let (num_rounds, simd_bytes) = common::calculate_blocks(encoded.len(), INPUT_BLOCK_SIZE);
for round in 0..num_rounds {
let offset = round * INPUT_BLOCK_SIZE;
let input_vec =
unsafe { _mm256_loadu_si256(encoded.as_ptr().add(offset) as *const __m256i) };
let hash_key = _mm256_and_si256(_mm256_srli_epi32(input_vec, 4), _mm256_set1_epi8(0x0F));
let check = _mm256_add_epi8(_mm256_shuffle_epi8(delta_check, hash_key), input_vec);
let invalid_mask = _mm256_cmpgt_epi8(check, _mm256_set1_epi8(0x1F));
if _mm256_movemask_epi8(invalid_mask) != 0 {
return false; }
let indices = _mm256_add_epi8(input_vec, _mm256_shuffle_epi8(delta_rebase, hash_key));
let decoded = unsafe { pack_5bit_to_8bit_avx2(indices) };
let lane0 = _mm256_castsi256_si128(decoded);
let lane1 = _mm256_extracti128_si256(decoded, 1);
let mut buf0 = [0u8; 16];
let mut buf1 = [0u8; 16];
unsafe {
_mm_storeu_si128(buf0.as_mut_ptr() as *mut __m128i, lane0);
_mm_storeu_si128(buf1.as_mut_ptr() as *mut __m128i, lane1);
}
result.extend_from_slice(&buf0[0..10]);
result.extend_from_slice(&buf1[0..10]);
}
if simd_bytes < encoded.len() {
let remainder = &encoded[simd_bytes..];
if !unsafe { decode_ssse3_impl(remainder, variant, result) } {
return false;
}
}
true
}
/// Packs the 5-bit indices in each 128-bit lane of `indices` into bytes.
///
/// Both lanes are processed independently by `pack_5bit_to_8bit` and the two
/// results are recombined in the same lane order.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 and SSSE3.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn pack_5bit_to_8bit_avx2(
    indices: std::arch::x86_64::__m256i,
) -> std::arch::x86_64::__m256i {
    use std::arch::x86_64::*;

    // SAFETY: forwarded under this function's own safety contract.
    let lower = unsafe { pack_5bit_to_8bit(_mm256_castsi256_si128(indices)) };
    let upper = unsafe { pack_5bit_to_8bit(_mm256_extracti128_si256(indices, 1)) };

    _mm256_set_m128i(upper, lower)
}
/// SSSE3 base32 decoder: turns 16 characters into 10 bytes per round.
///
/// Whatever the 16-character rounds do not cover is decoded by the scalar
/// decoder. `encoded` must already have its padding removed.
///
/// Every character is validated against the two ASCII runs of the selected
/// alphabet before it is decoded. Validation uses explicit lower and upper
/// bounds for both runs, so control characters, punctuation, lowercase
/// letters, digits outside the alphabet and non-ASCII bytes are all rejected.
/// This keeps the SIMD rounds in agreement with the scalar tail below.
///
/// Returns `false` as soon as an invalid character is found; `result` may
/// then hold partial output.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSSE3.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn decode_ssse3_impl(encoded: &[u8], variant: Base32Variant, result: &mut Vec<u8>) -> bool {
    use std::arch::x86_64::*;

    const INPUT_BLOCK_SIZE: usize = 16;

    // Only the rebase table is used. The companion check table cannot express
    // a lower bound per nibble, and its 0x7F entries wrap negative under the
    // signed compare, so it lets most out-of-alphabet bytes through.
    // SAFETY: the helper requires SSSE3, covered by this function's contract.
    let (_, delta_rebase) = unsafe { get_decode_delta_tables(variant) };

    // Inclusive ASCII bounds of the digit run and the letter run.
    let (digit_lo, digit_hi, letter_lo, letter_hi) = match variant {
        Base32Variant::Rfc4648 => (b'2', b'7', b'A', b'Z'),
        Base32Variant::Rfc4648Hex => (b'0', b'9', b'A', b'V'),
    };
    // All bounds are below 0x80, so signed comparisons are exact for ASCII
    // and any byte >= 0x80 (negative as i8) fails both lower bounds.
    let below_digits = _mm_set1_epi8(digit_lo as i8 - 1);
    let above_digits = _mm_set1_epi8(digit_hi as i8 + 1);
    let below_letters = _mm_set1_epi8(letter_lo as i8 - 1);
    let above_letters = _mm_set1_epi8(letter_hi as i8 + 1);
    let low_nibble = _mm_set1_epi8(0x0F);

    let (num_rounds, simd_bytes) = common::calculate_blocks(encoded.len(), INPUT_BLOCK_SIZE);

    for round in 0..num_rounds {
        let offset = round * INPUT_BLOCK_SIZE;
        // SAFETY: relies on `calculate_blocks` keeping
        // `num_rounds * INPUT_BLOCK_SIZE` within `encoded.len()` (its
        // definition is not visible here), so the 16-byte load is in bounds.
        let input_vec = unsafe { _mm_loadu_si128(encoded.as_ptr().add(offset) as *const __m128i) };

        let is_digit = _mm_and_si128(
            _mm_cmpgt_epi8(input_vec, below_digits),
            _mm_cmpgt_epi8(above_digits, input_vec),
        );
        let is_letter = _mm_and_si128(
            _mm_cmpgt_epi8(input_vec, below_letters),
            _mm_cmpgt_epi8(above_letters, input_vec),
        );
        let valid = _mm_or_si128(is_digit, is_letter);
        // All 16 lanes valid sets the low 16 mask bits.
        if _mm_movemask_epi8(valid) != 0xFFFF {
            return false;
        }

        // The high nibble picks the rebase offset: 3 for digits, 4/5 for letters.
        let hash_key = _mm_and_si128(_mm_srli_epi32(input_vec, 4), low_nibble);
        let indices = _mm_add_epi8(input_vec, _mm_shuffle_epi8(delta_rebase, hash_key));
        // SAFETY: the helper requires SSSE3, covered by this function's contract.
        let decoded = unsafe { pack_5bit_to_8bit(indices) };

        let mut output_buf = [0u8; 16];
        // SAFETY: `output_buf` is exactly 16 bytes and the store is unaligned.
        unsafe {
            _mm_storeu_si128(output_buf.as_mut_ptr() as *mut __m128i, decoded);
        }
        // The vector carries 10 decoded bytes followed by filler.
        result.extend_from_slice(&output_buf[0..10]);
    }

    if simd_bytes < encoded.len() {
        let remainder = &encoded[simd_bytes..];
        if !decode_scalar_remainder(
            remainder,
            &mut |c| match variant {
                Base32Variant::Rfc4648 => match c {
                    b'A'..=b'Z' => Some(c - b'A'),
                    b'2'..=b'7' => Some(c - b'2' + 26),
                    _ => None,
                },
                Base32Variant::Rfc4648Hex => match c {
                    b'0'..=b'9' => Some(c - b'0'),
                    b'A'..=b'V' => Some(c - b'A' + 10),
                    _ => None,
                },
            },
            result,
        ) {
            return false;
        }
    }

    true
}
/// Builds the two 16-entry lookup tables used by the SIMD decoders.
///
/// Both tables are indexed by the high nibble of an input character. Only
/// nibbles 3, 4 and 5 (ASCII `0x30..=0x5F`) carry alphabet-specific values:
///
/// * `delta_check[n]` is `0x1F - last_valid_char_in_row_n`, so adding it to a
///   character yields at most `0x1F` for characters up to the row's last
///   valid one. All other rows hold `0x7F`.
/// * `delta_rebase[n]` is the offset that turns a character of row `n` into
///   its 5-bit index. All other rows hold `0`.
///
/// Returns `(delta_check, delta_rebase)`.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSSE3.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn get_decode_delta_tables(
    variant: Base32Variant,
) -> (std::arch::x86_64::__m128i, std::arch::x86_64::__m128i) {
    use std::arch::x86_64::*;

    // Entries for high nibbles 3, 4 and 5, in that order.
    let (check_rows, rebase_rows): ([i8; 3], [i8; 3]) = match variant {
        Base32Variant::Rfc4648 => (
            // Rows end at '7' (0x37), 'O' (0x4F) and 'Z' (0x5A).
            [
                (0x1F - 0x37) as i8,
                (0x1F - 0x4F) as i8,
                (0x1F - 0x5A) as i8,
            ],
            // '2' maps to 26, 'A' maps to 0.
            [
                (26i16 - b'2' as i16) as i8,
                (0i16 - b'A' as i16) as i8,
                (0i16 - b'A' as i16) as i8,
            ],
        ),
        Base32Variant::Rfc4648Hex => (
            // Rows end at '9' (0x39), 'O' (0x4F) and 'V' (0x56).
            [
                (0x1F - 0x39) as i8,
                (0x1F - 0x4F) as i8,
                (0x1F - 0x56) as i8,
            ],
            // '0' maps to 0, 'A' maps to 10.
            [
                (0i16 - b'0' as i16) as i8,
                (10i16 - b'A' as i16) as i8,
                (10i16 - b'A' as i16) as i8,
            ],
        ),
    };

    let mut check_table = [0x7Fi8; 16];
    let mut rebase_table = [0i8; 16];
    check_table[3..6].copy_from_slice(&check_rows);
    rebase_table[3..6].copy_from_slice(&rebase_rows);

    // SAFETY: both arrays are exactly 16 bytes and the loads are unaligned.
    unsafe {
        (
            _mm_loadu_si128(check_table.as_ptr() as *const __m128i),
            _mm_loadu_si128(rebase_table.as_ptr() as *const __m128i),
        )
    }
}
/// Packs sixteen 5-bit indices into ten bytes.
///
/// Each 64-bit half of `indices` holds eight indices that form one 40-bit
/// group. The group is assembled with two multiply-add steps and then
/// byte-reordered so the ten payload bytes land in bytes 0..10 of the result.
/// Bytes 10..16 of the result are filler (copies of packed byte 0).
///
/// Every input byte is expected to be in `0..=31`.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSSE3.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn pack_5bit_to_8bit(indices: std::arch::x86_64::__m128i) -> std::arch::x86_64::__m128i {
    use std::arch::x86_64::*;

    // Step 1: merge index pairs into 10-bit values: (even << 5) | odd.
    let pair_weights = _mm_set1_epi16(0x0120);
    let pairs = _mm_maddubs_epi16(indices, pair_weights);

    // Step 2: merge 10-bit pairs into 20-bit values. The first dword of each
    // half is shifted up by 4 to leave room for the carry in step 3.
    let quad_weights = _mm_setr_epi32(0x00104000, 0x00010400, 0x00104000, 0x00010400);
    let quads = _mm_madd_epi16(pairs, quad_weights);

    // Step 3: move the top 4 bits of the second dword into the low 4 bits of
    // the first, so the first dword holds three complete bytes.
    let carried = _mm_or_si128(quads, _mm_srli_epi64(quads, 48));

    // Step 4: pick the five payload bytes of each half in big-endian order.
    let gather = _mm_setr_epi8(2, 1, 0, 5, 4, 10, 9, 8, 13, 12, 0, 0, 0, 0, 0, 0);
    _mm_shuffle_epi8(carried, gather)
}
/// Decodes `data` with the shared scalar chunked decoder.
///
/// `char_to_index` maps one input byte to its alphabet index, or `None` for a
/// byte outside the alphabet. Decoded bytes are appended to `result`.
///
/// Returns the helper's status unchanged; callers in this file treat `false`
/// as a decoding failure.
fn decode_scalar_remainder(
data: &[u8],
char_to_index: &mut dyn FnMut(u8) -> Option<u8>,
result: &mut Vec<u8>,
) -> bool {
// NOTE(review): the trailing `5` is presumably the bits per character for
// base32 — confirm against `common::decode_scalar_chunked`.
common::decode_scalar_chunked(data, char_to_index, result, 5)
}
#[cfg(test)]
// NOTE(review): presumably needed because `Dictionary::new_with_mode` is
// deprecated — confirm and drop the allow if it is not.
#[allow(deprecated)]
mod tests {
    use super::*;
    use crate::core::config::EncodingMode;
    use crate::core::dictionary::Dictionary;

    /// Standard RFC 4648 base32 alphabet with `=` padding.
    fn make_base32_dict() -> Dictionary {
        let chars: Vec<char> = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567".chars().collect();
        Dictionary::new_with_mode(chars, EncodingMode::Chunked, Some('=')).unwrap()
    }

    /// RFC 4648 "base32hex" alphabet with `=` padding.
    fn make_base32_hex_dict() -> Dictionary {
        let chars: Vec<char> = "0123456789ABCDEFGHIJKLMNOPQRSTUV".chars().collect();
        Dictionary::new_with_mode(chars, EncodingMode::Chunked, Some('=')).unwrap()
    }

    // The assertions below use `expect` rather than `if let Some(..)`, so an
    // unexpected `None` from `encode` or `decode` fails the test instead of
    // silently skipping the comparison.

    #[test]
    fn test_encode_known_values() {
        let dictionary = make_base32_dict();
        let test_cases = [
            (b"".as_slice(), ""),
            (b"f", "MY======"),
            (b"fo", "MZXQ===="),
            (b"foo", "MZXW6==="),
            (b"foob", "MZXW6YQ="),
            (b"fooba", "MZXW6YTB"),
            (b"foobar", "MZXW6YTBOI======"),
        ];
        for (input, expected) in test_cases {
            let simd_result = encode(input, &dictionary, Base32Variant::Rfc4648)
                .expect("encode should produce output");
            assert_eq!(simd_result, expected, "Failed for input: {:?}", input);
        }
    }

    #[test]
    fn test_encode_hex_variant() {
        let dictionary = make_base32_hex_dict();
        let test_cases = [
            (b"".as_slice(), ""),
            (b"f", "CO======"),
            (b"fo", "CPNG===="),
            (b"foo", "CPNMU==="),
        ];
        for (input, expected) in test_cases {
            let simd_result = encode(input, &dictionary, Base32Variant::Rfc4648Hex)
                .expect("encode should produce output");
            assert_eq!(simd_result, expected, "Failed for input: {:?}", input);
        }
    }

    #[test]
    fn test_decode_round_trip() {
        let dictionary = make_base32_dict();
        for len in 0..100 {
            let original: Vec<u8> = (0..len).map(|i| (i * 7) as u8).collect();
            let encoded = encode(&original, &dictionary, Base32Variant::Rfc4648)
                .expect("encode should produce output");
            let decoded = decode(&encoded, Base32Variant::Rfc4648)
                .expect("decode should accept encoder output");
            assert_eq!(decoded, original, "Round-trip failed at length {}", len);
        }
    }

    #[test]
    fn test_decode_hex_round_trip() {
        let dictionary = make_base32_hex_dict();
        for len in 0..100 {
            let original: Vec<u8> = (0..len).map(|i| (i * 7) as u8).collect();
            let encoded = encode(&original, &dictionary, Base32Variant::Rfc4648Hex)
                .expect("encode should produce output");
            let decoded = decode(&encoded, Base32Variant::Rfc4648Hex)
                .expect("decode should accept encoder output");
            assert_eq!(decoded, original, "Round-trip failed at length {}", len);
        }
    }

    #[test]
    fn test_avx2_large_input() {
        let dictionary = make_base32_dict();
        let test_data: Vec<u8> = (0..40).collect();
        let simd_result = encode(&test_data, &dictionary, Base32Variant::Rfc4648)
            .expect("encode should produce output");
        let decoded = decode(&simd_result, Base32Variant::Rfc4648)
            .expect("decode should accept encoder output");
        assert_eq!(decoded, test_data, "AVX2 round-trip failed");
    }

    #[test]
    fn test_avx2_decode_large() {
        let dictionary = make_base32_dict();
        let test_data: Vec<u8> = (0..40).map(|i| (i * 3) as u8).collect();
        let encoded = encode(&test_data, &dictionary, Base32Variant::Rfc4648)
            .expect("encode should produce output");
        let decoded = decode(&encoded, Base32Variant::Rfc4648)
            .expect("decode should accept encoder output");
        assert_eq!(decoded, test_data, "AVX2 decode failed");
    }

    #[test]
    fn test_padding_validation_correct() {
        // Fully padded inputs.
        assert!(decode("MY======", Base32Variant::Rfc4648).is_some());
        assert!(decode("MZXQ====", Base32Variant::Rfc4648).is_some());
        assert!(decode("MZXW6===", Base32Variant::Rfc4648).is_some());
        assert!(decode("MZXW6YQ=", Base32Variant::Rfc4648).is_some());
        assert!(decode("MZXW6YTB", Base32Variant::Rfc4648).is_some());
        // Unpadded inputs with a legal length.
        assert!(decode("MY", Base32Variant::Rfc4648).is_some());
        assert!(decode("MZXQ", Base32Variant::Rfc4648).is_some());
    }

    #[test]
    fn test_padding_validation_incorrect() {
        // Wrong number of padding characters.
        assert!(decode("MY=====", Base32Variant::Rfc4648).is_none());
        assert!(decode("MY=======", Base32Variant::Rfc4648).is_none());
        assert!(decode("MZXQ===", Base32Variant::Rfc4648).is_none());
        assert!(decode("MZXW6==", Base32Variant::Rfc4648).is_none());
        assert!(decode("MZXW6YQ==", Base32Variant::Rfc4648).is_none());
        assert!(decode("MY=", Base32Variant::Rfc4648).is_none());
        // Unpadded inputs whose length cannot encode whole bytes.
        assert!(decode("M", Base32Variant::Rfc4648).is_none());
        assert!(decode("MYX", Base32Variant::Rfc4648).is_none());
        assert!(decode("MZXW6Y", Base32Variant::Rfc4648).is_none());
    }
}