use super::super::common;
use crate::core::dictionary::Dictionary;
use crate::simd::variants::DictionaryVariant;
pub fn encode(data: &[u8], dictionary: &Dictionary, variant: DictionaryVariant) -> Option<String> {
let output_len = data.len().div_ceil(3) * 4;
let mut result = String::with_capacity(output_len);
#[cfg(target_arch = "x86_64")]
unsafe {
if is_x86_feature_detected!("avx2") {
encode_avx2_impl(data, dictionary, variant, &mut result);
} else {
encode_ssse3_impl(data, dictionary, variant, &mut result);
}
}
#[cfg(not(target_arch = "x86_64"))]
unsafe {
encode_ssse3_impl(data, dictionary, variant, &mut result);
}
Some(result)
}
pub fn decode(encoded: &str, variant: DictionaryVariant) -> Option<Vec<u8>> {
let encoded_bytes = encoded.as_bytes();
let input_no_padding = encoded.trim_end_matches('=');
let output_len = (input_no_padding.len() / 4) * 3
+ match input_no_padding.len() % 4 {
0 => 0,
2 => 1,
3 => 2,
_ => return None, };
let mut result = Vec::with_capacity(output_len);
#[cfg(target_arch = "x86_64")]
unsafe {
if is_x86_feature_detected!("avx2") {
if !decode_avx2_impl(encoded_bytes, variant, &mut result) {
return None;
}
} else if !decode_ssse3_impl(encoded_bytes, variant, &mut result) {
return None;
}
}
#[cfg(not(target_arch = "x86_64"))]
unsafe {
if !decode_ssse3_impl(encoded_bytes, variant, &mut result) {
return None;
}
}
Some(result)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// AVX2 base64 encoder: consumes 24 input bytes per round (via two
/// overlapping 16-byte loads) and emits 32 output characters, then hands the
/// tail — and any input shorter than 28 bytes — to the SSSE3/scalar path.
///
/// `dictionary` is only used by the scalar remainder; the SIMD translation is
/// driven by `variant` alone — assumes the two describe the same alphabet
/// (TODO confirm at the call sites).
unsafe fn encode_avx2_impl(
    data: &[u8],
    dictionary: &Dictionary,
    variant: DictionaryVariant,
    result: &mut String,
) {
    use std::arch::x86_64::*;
    // 24 input bytes are consumed per round (they expand to 32 chars).
    const BLOCK_SIZE: usize = 24;
    // Too short to place two overlapping 16-byte loads safely.
    if data.len() < 28 {
        unsafe {
            encode_ssse3_impl(data, dictionary, variant, result);
        }
        return;
    }
    // Reserve 8 trailing bytes so the second (offset+12) 16-byte load of the
    // final round cannot read past the end of `data`.
    let safe_len = if data.len() >= 8 { data.len() - 8 } else { 0 };
    let (num_rounds, simd_bytes) = common::calculate_blocks(safe_len, BLOCK_SIZE);
    let mut offset = 0;
    for _ in 0..num_rounds {
        let encoded = unsafe {
            // The two loads cover bytes [offset, offset+28); lane 0 gets
            // bytes 0..16 and lane 1 bytes 12..28, so each lane encodes its
            // own 12-byte payload.
            let input_lo = _mm_loadu_si128(data.as_ptr().add(offset) as *const __m128i);
            let input_hi = _mm_loadu_si128(data.as_ptr().add(offset + 12) as *const __m128i);
            let input_256 = _mm256_set_m128i(input_hi, input_lo);
            let reshuffled = reshuffle_avx2(input_256);
            translate_avx2(reshuffled, variant)
        };
        let mut output_buf = [0u8; 32];
        unsafe {
            _mm256_storeu_si256(output_buf.as_mut_ptr() as *mut __m256i, encoded);
        }
        // Output is pure ASCII, so the byte-to-char push is loss-free.
        for &byte in &output_buf {
            result.push(byte as char);
        }
        offset += BLOCK_SIZE;
    }
    // Encode whatever the SIMD loop could not cover (including padding).
    if simd_bytes < data.len() {
        unsafe {
            encode_ssse3_impl(&data[simd_bytes..], dictionary, variant, result);
        }
    }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Splits each 3-byte group into four 6-bit base64 indices, one per output
/// byte (Muła's mask-and-multiply scheme), operating on both 128-bit lanes
/// independently; each lane holds a 12-byte payload.
unsafe fn reshuffle_avx2(input: std::arch::x86_64::__m256i) -> std::arch::x86_64::__m256i {
    use std::arch::x86_64::*;
    // Duplicate/reorder source bytes so each 32-bit word holds one 3-byte
    // group laid out for the masked multiplies below (same mask per lane;
    // note `set_epi8` lists bytes high-to-low).
    let shuffle_mask = _mm256_set_epi8(
        10, 11, 9, 10, 7, 8, 6, 7, 4, 5, 3, 4, 1, 2, 0, 1,
        10, 11, 9, 10, 7, 8, 6, 7, 4, 5, 3, 4, 1, 2, 0, 1,
    );
    let shuffled = _mm256_shuffle_epi8(input, shuffle_mask);
    // High pair of 6-bit fields: mask, then shift right via 16-bit mulhi.
    let t0 = _mm256_and_si256(shuffled, _mm256_set1_epi32(0x0FC0FC00_u32 as i32));
    let t1 = _mm256_mulhi_epu16(t0, _mm256_set1_epi32(0x04000040_u32 as i32));
    // Low pair of 6-bit fields: mask, then shift left via 16-bit mullo.
    let t2 = _mm256_and_si256(shuffled, _mm256_set1_epi32(0x003F03F0_u32 as i32));
    let t3 = _mm256_mullo_epi16(t2, _mm256_set1_epi32(0x01000010_u32 as i32));
    _mm256_or_si256(t1, t3)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Converts 6-bit indices (0..=63) into base64 ASCII for both lanes by adding
/// a per-range offset selected from a 16-entry LUT.
unsafe fn translate_avx2(
    indices: std::arch::x86_64::__m256i,
    variant: DictionaryVariant,
) -> std::arch::x86_64::__m256i {
    use std::arch::x86_64::*;
    // Offsets per range: [0]=+65 ('A' for 0..=25), [1]=+71 ('a'-26 for
    // 26..=51), [2..=11]=-4 ('0'-52 for 52..=61), [12]/[13] map indices
    // 62/63 to '+','/' (standard) or '-','_' (URL-safe).
    let lut = match variant {
        DictionaryVariant::Base64Standard => _mm256_setr_epi8(
            65, 71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -19, -16, 0, 0,
            65, 71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -19, -16, 0, 0,
        ),
        DictionaryVariant::Base64Url => _mm256_setr_epi8(
            65, 71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -17, 32, 0, 0,
            65, 71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -17, 32, 0, 0,
        ),
    };
    // LUT index: saturating (idx - 51) is 0 for idx <= 51; subtracting the
    // -1 compare mask then adds 1 for idx > 25, separating 'A'-'Z' (index 0)
    // from 'a'-'z' (index 1) and shifting the remaining ranges up by one.
    let mut lut_indices = _mm256_subs_epu8(indices, _mm256_set1_epi8(51));
    let mask = _mm256_cmpgt_epi8(indices, _mm256_set1_epi8(25));
    lut_indices = _mm256_sub_epi8(lut_indices, mask);
    let offsets = _mm256_shuffle_epi8(lut, lut_indices);
    _mm256_add_epi8(indices, offsets)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn decode_avx2_impl(
encoded: &[u8],
variant: DictionaryVariant,
result: &mut Vec<u8>,
) -> bool {
use std::arch::x86_64::*;
const INPUT_BLOCK_SIZE: usize = 32;
let input_no_padding = if let Some(last_non_pad) = encoded.iter().rposition(|&b| b != b'=') {
&encoded[..=last_non_pad]
} else {
encoded
};
if input_no_padding.len() < 32 {
return unsafe { decode_ssse3_impl(input_no_padding, variant, result) };
}
let (lut_lo, lut_hi, lut_roll) = unsafe {
let (lut_lo_128, lut_hi_128, lut_roll_128) = get_decode_luts(variant);
(
_mm256_broadcastsi128_si256(lut_lo_128),
_mm256_broadcastsi128_si256(lut_hi_128),
_mm256_broadcastsi128_si256(lut_roll_128),
)
};
let (num_rounds, simd_bytes) =
common::calculate_blocks(input_no_padding.len(), INPUT_BLOCK_SIZE);
for round in 0..num_rounds {
let offset = round * INPUT_BLOCK_SIZE;
let input_vec =
unsafe { _mm256_loadu_si256(input_no_padding.as_ptr().add(offset) as *const __m256i) };
if !unsafe { validate_avx2(input_vec, lut_lo, lut_hi) } {
return false; }
let decoded = unsafe {
let indices = translate_decode_avx2(input_vec, lut_hi, lut_roll);
reshuffle_decode_avx2(indices)
};
let (buf0, buf1) = unsafe {
let lane0 = _mm256_castsi256_si128(decoded);
let lane1 = _mm256_extracti128_si256(decoded, 1);
let mut buf0 = [0u8; 16];
let mut buf1 = [0u8; 16];
_mm_storeu_si128(buf0.as_mut_ptr() as *mut __m128i, lane0);
_mm_storeu_si128(buf1.as_mut_ptr() as *mut __m128i, lane1);
(buf0, buf1)
};
result.extend_from_slice(&buf0[0..12]);
result.extend_from_slice(&buf1[0..12]);
}
if simd_bytes < input_no_padding.len() {
let remainder = &input_no_padding[simd_bytes..];
if !unsafe { decode_ssse3_impl(remainder, variant, result) } {
return false;
}
}
true
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Returns `true` when every byte of `input` is a valid standard-alphabet
/// base64 character. A byte is valid iff `lut_lo[lo_nibble] &
/// lut_hi[hi_nibble] == 0` (Muła/aklomp nibble classification).
unsafe fn validate_avx2(
    input: std::arch::x86_64::__m256i,
    lut_lo: std::arch::x86_64::__m256i,
    lut_hi: std::arch::x86_64::__m256i,
) -> bool {
    use std::arch::x86_64::*;
    let lo_nibbles = _mm256_and_si256(input, _mm256_set1_epi8(0x0F));
    let hi_nibbles = _mm256_and_si256(_mm256_srli_epi32(input, 4), _mm256_set1_epi8(0x0F));
    let lo_lookup = _mm256_shuffle_epi8(lut_lo, lo_nibbles);
    let hi_lookup = _mm256_shuffle_epi8(lut_hi, hi_nibbles);
    let validation = _mm256_and_si256(lo_lookup, hi_lookup);
    // BUGFIX: `_mm256_movemask_epi8` reads only each byte's sign bit, and the
    // LUT products are small positive values (<= 0x1B), so testing the mask
    // of `validation` directly was always 0 — validation never rejected
    // anything. Compare against zero first so every nonzero (invalid) byte
    // becomes 0xFF and sets its mask bit.
    let invalid = _mm256_cmpgt_epi8(validation, _mm256_setzero_si256());
    _mm256_movemask_epi8(invalid) == 0
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Maps ASCII base64 characters to their 6-bit values by adding a roll
/// offset selected by each byte's high nibble; bytes equal to '/' (0x2F)
/// shift their roll index down by one so '/' can differ from the other
/// characters sharing high nibble 2.
///
/// `_lut_hi` is unused here; it is kept so call sites can pass the LUT
/// triple uniformly.
unsafe fn translate_decode_avx2(
    input: std::arch::x86_64::__m256i,
    _lut_hi: std::arch::x86_64::__m256i,
    lut_roll: std::arch::x86_64::__m256i,
) -> std::arch::x86_64::__m256i {
    use std::arch::x86_64::*;
    let hi_nibbles = _mm256_and_si256(_mm256_srli_epi32(input, 4), _mm256_set1_epi8(0x0F));
    // eq_2f is -1 where the byte is '/', moving that byte's LUT index to
    // hi_nibble - 1.
    let eq_2f = _mm256_cmpeq_epi8(input, _mm256_set1_epi8(0x2F));
    let roll_index = _mm256_add_epi8(eq_2f, hi_nibbles);
    let offsets = _mm256_shuffle_epi8(lut_roll, roll_index);
    _mm256_add_epi8(input, offsets)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
/// Packs 32 decoded 6-bit values back into 24 bytes (12 per 128-bit lane);
/// the last 4 bytes of each lane become zero filler discarded by the caller.
unsafe fn reshuffle_decode_avx2(indices: std::arch::x86_64::__m256i) -> std::arch::x86_64::__m256i {
    use std::arch::x86_64::*;
    // Merge adjacent 6-bit fields pairwise into 12-bit values...
    let merge_ab_and_bc = _mm256_maddubs_epi16(indices, _mm256_set1_epi32(0x01400140u32 as i32));
    // ...then merge 12-bit pairs into 24-bit values within 32-bit words.
    let final_32bit = _mm256_madd_epi16(merge_ab_and_bc, _mm256_set1_epi32(0x00011000u32 as i32));
    // Reorder into contiguous byte triples; -1 indices produce zero bytes.
    _mm256_shuffle_epi8(
        final_32bit,
        _mm256_setr_epi8(
            2, 1, 0, 6, 5, 4, 10, 9, 8, 14, 13, 12, -1, -1, -1, -1,
            2, 1, 0, 6, 5, 4, 10, 9, 8, 14, 13, 12, -1, -1, -1, -1,
        ),
    )
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
/// SSSE3 base64 encoder: consumes 12 input bytes per round (via a 16-byte
/// load) and emits 16 output characters; the tail — and inputs shorter than
/// 16 bytes — go to the scalar fallback, which also emits '=' padding.
unsafe fn encode_ssse3_impl(
    data: &[u8],
    dictionary: &Dictionary,
    variant: DictionaryVariant,
    result: &mut String,
) {
    use std::arch::x86_64::*;
    // 12 input bytes are consumed per round (they expand to 16 chars).
    const BLOCK_SIZE: usize = 12;
    if data.len() < 16 {
        encode_scalar_remainder(data, dictionary, result);
        return;
    }
    // Reserve 4 trailing bytes so the final round's 16-byte load (which
    // consumes only 12 bytes) stays in bounds.
    let safe_len = if data.len() >= 4 { data.len() - 4 } else { 0 };
    let (num_rounds, simd_bytes) = common::calculate_blocks(safe_len, BLOCK_SIZE);
    let mut offset = 0;
    for _ in 0..num_rounds {
        let encoded = unsafe {
            let input_vec = _mm_loadu_si128(data.as_ptr().add(offset) as *const __m128i);
            let reshuffled = reshuffle(input_vec);
            translate(reshuffled, variant)
        };
        let mut output_buf = [0u8; 16];
        unsafe {
            _mm_storeu_si128(output_buf.as_mut_ptr() as *mut __m128i, encoded);
        }
        // Output is pure ASCII, so the byte-to-char push is loss-free.
        for &byte in &output_buf {
            result.push(byte as char);
        }
        offset += BLOCK_SIZE;
    }
    // Scalar tail handles the remaining bytes and padding.
    if simd_bytes < data.len() {
        encode_scalar_remainder(&data[simd_bytes..], dictionary, result);
    }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
/// Splits each 3-byte group of the low 12 input bytes into four 6-bit base64
/// indices, one per output byte (Muła's mask-and-multiply scheme).
unsafe fn reshuffle(input: std::arch::x86_64::__m128i) -> std::arch::x86_64::__m128i {
    use std::arch::x86_64::*;
    // Duplicate/reorder source bytes so each 32-bit word holds one 3-byte
    // group laid out for the masked multiplies below (`set_epi8` lists
    // bytes high-to-low).
    let shuffled = _mm_shuffle_epi8(
        input,
        _mm_set_epi8(
            10, 11, 9, 10, 7, 8, 6, 7, 4, 5, 3, 4, 1, 2, 0, 1,
        ),
    );
    // High pair of 6-bit fields: mask, then shift right via 16-bit mulhi.
    let t0 = _mm_and_si128(shuffled, _mm_set1_epi32(0x0FC0FC00_u32 as i32));
    let t1 = _mm_mulhi_epu16(t0, _mm_set1_epi32(0x04000040_u32 as i32));
    // Low pair of 6-bit fields: mask, then shift left via 16-bit mullo.
    let t2 = _mm_and_si128(shuffled, _mm_set1_epi32(0x003F03F0_u32 as i32));
    let t3 = _mm_mullo_epi16(t2, _mm_set1_epi32(0x01000010_u32 as i32));
    _mm_or_si128(t1, t3)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
/// Converts 6-bit indices (0..=63) into base64 ASCII by adding a per-range
/// offset selected from a 16-entry LUT.
unsafe fn translate(
    indices: std::arch::x86_64::__m128i,
    variant: DictionaryVariant,
) -> std::arch::x86_64::__m128i {
    use std::arch::x86_64::*;
    // Offsets per range: [0]=+65 ('A' for 0..=25), [1]=+71 ('a'-26 for
    // 26..=51), [2..=11]=-4 ('0'-52 for 52..=61), [12]/[13] map indices
    // 62/63 to '+','/' (standard) or '-','_' (URL-safe).
    let lut = match variant {
        DictionaryVariant::Base64Standard => _mm_setr_epi8(
            65, 71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
            -19, -16, 0, 0,
        ),
        DictionaryVariant::Base64Url => _mm_setr_epi8(
            65, 71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
            -17, 32, 0, 0,
        ),
    };
    // LUT index: saturating (idx - 51) is 0 for idx <= 51; subtracting the
    // -1 compare mask then adds 1 for idx > 25, separating 'A'-'Z' (index 0)
    // from 'a'-'z' (index 1) and shifting the remaining ranges up by one.
    let mut lut_indices = _mm_subs_epu8(indices, _mm_set1_epi8(51));
    let mask = _mm_cmpgt_epi8(indices, _mm_set1_epi8(25));
    lut_indices = _mm_sub_epi8(lut_indices, mask);
    let offsets = _mm_shuffle_epi8(lut, lut_indices);
    _mm_add_epi8(indices, offsets)
}
/// Scalar fallback: encodes `data` (including any final partial group and
/// its '=' padding) via the shared chunked encoder. Compiled on every
/// architecture, unlike the SIMD kernels above.
fn encode_scalar_remainder(data: &[u8], dictionary: &Dictionary, result: &mut String) {
    common::encode_scalar_chunked(data, dictionary, result);
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
/// SSSE3 base64 decoder: 16 input chars -> 12 output bytes per round, with a
/// scalar tail. Returns `false` on any invalid character.
unsafe fn decode_ssse3_impl(
    encoded: &[u8],
    variant: DictionaryVariant,
    result: &mut Vec<u8>,
) -> bool {
    use std::arch::x86_64::*;
    const INPUT_BLOCK_SIZE: usize = 16;
    const OUTPUT_BLOCK_SIZE: usize = 12;
    // Strip trailing '=' padding. BUGFIX: an all-padding input previously
    // kept the padded slice and was rejected; treat it as empty instead.
    let input_no_padding = match encoded.iter().rposition(|&b| b != b'=') {
        Some(last_non_pad) => &encoded[..=last_non_pad],
        None => &encoded[..0],
    };
    // BUGFIX: the LUTs from `get_decode_luts` only validate/translate the
    // standard alphabet — '-' (0x2D) and '_' (0x5F) are rejected by the
    // lo/hi tables and the URL roll table mistranslates '-' — so the
    // URL-safe variant is decoded entirely by the scalar loop below.
    let simd_bytes = if matches!(variant, DictionaryVariant::Base64Standard) {
        let (lut_lo, lut_hi, lut_roll) = unsafe { get_decode_luts(variant) };
        let (num_rounds, simd_bytes) =
            common::calculate_blocks(input_no_padding.len(), INPUT_BLOCK_SIZE);
        for round in 0..num_rounds {
            let offset = round * INPUT_BLOCK_SIZE;
            let input_vec = unsafe {
                _mm_loadu_si128(input_no_padding.as_ptr().add(offset) as *const __m128i)
            };
            if !unsafe { validate(input_vec, lut_lo, lut_hi) } {
                return false;
            }
            let decoded = unsafe {
                let indices = translate_decode(input_vec, lut_hi, lut_roll);
                reshuffle_decode(indices)
            };
            let mut output_buf = [0u8; 16];
            unsafe {
                _mm_storeu_si128(output_buf.as_mut_ptr() as *mut __m128i, decoded);
            }
            result.extend_from_slice(&output_buf[0..OUTPUT_BLOCK_SIZE]);
        }
        simd_bytes
    } else {
        0
    };
    // Scalar tail (or the whole input for the URL-safe variant).
    if simd_bytes < input_no_padding.len() {
        let remainder = &input_no_padding[simd_bytes..];
        if !decode_scalar_remainder(
            remainder,
            &mut |c| match c {
                b'A'..=b'Z' => Some(c - b'A'),
                b'a'..=b'z' => Some(c - b'a' + 26),
                b'0'..=b'9' => Some(c - b'0' + 52),
                b'+' if matches!(variant, DictionaryVariant::Base64Standard) => Some(62),
                b'/' if matches!(variant, DictionaryVariant::Base64Standard) => Some(63),
                b'-' if matches!(variant, DictionaryVariant::Base64Url) => Some(62),
                b'_' if matches!(variant, DictionaryVariant::Base64Url) => Some(63),
                _ => None,
            },
            result,
        ) {
            return false;
        }
    }
    true
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
/// Returns the three 16-byte LUTs used by the SIMD decoder:
/// - `lut_lo` / `lut_hi`: nibble-indexed validation tables (Muła/aklomp
///   style); a byte is valid iff `lut_lo[lo_nibble] & lut_hi[hi_nibble] == 0`.
/// - `lut_roll`: per-high-nibble offsets added to the ASCII value to obtain
///   the 6-bit index ('/' is special-cased by the callers via an index shift).
///
/// NOTE(review): `lut_lo`/`lut_hi` encode only the *standard* alphabet —
/// they flag '-' (0x2D) and '_' (0x5F) as invalid — and the URL `lut_roll`
/// entry for high nibble 2 (-32) does not map '-' (45) to index 62. The
/// URL-safe variant therefore cannot be decoded correctly through these
/// tables; confirm that callers route URL-safe input to the scalar decoder.
unsafe fn get_decode_luts(
    variant: DictionaryVariant,
) -> (
    std::arch::x86_64::__m128i,
    std::arch::x86_64::__m128i,
    std::arch::x86_64::__m128i,
) {
    use std::arch::x86_64::*;
    let lut_lo = _mm_setr_epi8(
        0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x13, 0x1A, 0x1B, 0x1B, 0x1B,
        0x1A,
    );
    let lut_hi = _mm_setr_epi8(
        0x10, 0x10, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
        0x10,
    );
    let lut_roll = match variant {
        DictionaryVariant::Base64Standard => {
            // [1]=+16 for '/' (index shifted by the caller), [2]=+19 for '+',
            // [3]=+4 for digits, [4..=5]=-65 for 'A'-'Z', [6..=7]=-71 for
            // 'a'-'z'.
            _mm_setr_epi8(0, 16, 19, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0)
        }
        DictionaryVariant::Base64Url => {
            _mm_setr_epi8(0, 17, -32, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0)
        }
    };
    (lut_lo, lut_hi, lut_roll)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
/// Returns `true` when every byte of `input` is a valid standard-alphabet
/// base64 character. A byte is valid iff `lut_lo[lo_nibble] &
/// lut_hi[hi_nibble] == 0` (Muła/aklomp nibble classification).
unsafe fn validate(
    input: std::arch::x86_64::__m128i,
    lut_lo: std::arch::x86_64::__m128i,
    lut_hi: std::arch::x86_64::__m128i,
) -> bool {
    use std::arch::x86_64::*;
    let lo_nibbles = _mm_and_si128(input, _mm_set1_epi8(0x0F));
    let hi_nibbles = _mm_and_si128(_mm_srli_epi32(input, 4), _mm_set1_epi8(0x0F));
    let lo_lookup = _mm_shuffle_epi8(lut_lo, lo_nibbles);
    let hi_lookup = _mm_shuffle_epi8(lut_hi, hi_nibbles);
    let validation = _mm_and_si128(lo_lookup, hi_lookup);
    // BUGFIX: `_mm_movemask_epi8` reads only each byte's sign bit, and the
    // LUT products are small positive values (<= 0x1B), so testing the mask
    // of `validation` directly was always 0 — validation never rejected
    // anything. Compare against zero first so every nonzero (invalid) byte
    // becomes 0xFF and sets its mask bit.
    let invalid = _mm_cmpgt_epi8(validation, _mm_setzero_si128());
    _mm_movemask_epi8(invalid) == 0
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
/// Maps ASCII base64 characters to their 6-bit values by adding a roll
/// offset selected by each byte's high nibble; bytes equal to '/' (0x2F)
/// shift their roll index down by one so '/' can differ from the other
/// characters sharing high nibble 2.
///
/// `_lut_hi` is unused here; it is kept so call sites can pass the LUT
/// triple uniformly.
unsafe fn translate_decode(
    input: std::arch::x86_64::__m128i,
    _lut_hi: std::arch::x86_64::__m128i,
    lut_roll: std::arch::x86_64::__m128i,
) -> std::arch::x86_64::__m128i {
    use std::arch::x86_64::*;
    let hi_nibbles = _mm_and_si128(_mm_srli_epi32(input, 4), _mm_set1_epi8(0x0F));
    // eq_2f is -1 where the byte is '/', moving that byte's LUT index to
    // hi_nibble - 1.
    let eq_2f = _mm_cmpeq_epi8(input, _mm_set1_epi8(0x2F));
    let roll_index = _mm_add_epi8(eq_2f, hi_nibbles);
    let offsets = _mm_shuffle_epi8(lut_roll, roll_index);
    _mm_add_epi8(input, offsets)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
/// Packs 16 decoded 6-bit values back into 12 bytes; the last 4 output
/// bytes become zero filler discarded by the caller.
unsafe fn reshuffle_decode(indices: std::arch::x86_64::__m128i) -> std::arch::x86_64::__m128i {
    use std::arch::x86_64::*;
    // Merge adjacent 6-bit fields pairwise into 12-bit values...
    let merge_ab_and_bc = _mm_maddubs_epi16(indices, _mm_set1_epi32(0x01400140u32 as i32));
    // ...then merge 12-bit pairs into 24-bit values within 32-bit words.
    let final_32bit = _mm_madd_epi16(merge_ab_and_bc, _mm_set1_epi32(0x00011000u32 as i32));
    // Reorder into contiguous byte triples; -1 indices produce zero bytes.
    _mm_shuffle_epi8(
        final_32bit,
        _mm_setr_epi8(
            2, 1, 0, 6, 5, 4, 10, 9, 8, 14, 13, 12, -1, -1, -1, -1,
        ),
    )
}
/// Scalar fallback decoder: maps each input character through
/// `char_to_index` and packs the resulting values into bytes via the shared
/// chunked decoder, reporting failure through the returned `bool`.
/// The trailing `6` is presumably the bits-per-character chunk width for
/// base64 — confirm against `common::decode_scalar_chunked`.
fn decode_scalar_remainder(
    data: &[u8],
    char_to_index: &mut dyn FnMut(u8) -> Option<u8>,
    result: &mut Vec<u8>,
) -> bool {
    common::decode_scalar_chunked(data, char_to_index, result, 6)
}
#[cfg(test)]
#[allow(deprecated)] // Dictionary::new_with_mode is presumably deprecated — kept from original.
mod tests {
    use super::*;
    use crate::core::config::EncodingMode;
    use crate::core::dictionary::Dictionary;

    /// Builds the standard base64 dictionary shared by these tests.
    fn make_base64_dict() -> Dictionary {
        let chars: Vec<char> = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
            .chars()
            .collect();
        Dictionary::new_with_mode(chars, EncodingMode::Chunked, Some('=')).unwrap()
    }

    // BUGFIX (test robustness): these tests previously guarded every call
    // with `if let Some(...)`, so an unexpected `None` made them pass
    // vacuously. They now unwrap with a message so failures are visible.

    #[test]
    fn test_encode_matches_scalar() {
        let dictionary = make_base64_dict();
        let test_data = b"Hello, World! This is a test of SIMD base64 encoding.";
        let simd_result = encode(test_data, &dictionary, DictionaryVariant::Base64Standard)
            .expect("encode should succeed");
        let scalar_result =
            crate::encoders::algorithms::chunked::encode_chunked(test_data, &dictionary);
        assert_eq!(
            simd_result, scalar_result,
            "SIMD and scalar should produce same output"
        );
    }

    #[test]
    fn test_encode_known_values() {
        let dictionary = make_base64_dict();
        let test_cases = [
            (b"Hello".as_slice(), "SGVsbG8="),
            (b"Hello, World!", "SGVsbG8sIFdvcmxkIQ=="),
            (b"a", "YQ=="),
            (b"ab", "YWI="),
            (b"abc", "YWJj"),
            (b"abcd", "YWJjZA=="),
            (b"abcde", "YWJjZGU="),
            (b"abcdef", "YWJjZGVm"),
            (b"", ""),
        ];
        for (input, expected) in test_cases {
            let simd_result = encode(input, &dictionary, DictionaryVariant::Base64Standard)
                .expect("encode should succeed");
            assert_eq!(simd_result, expected, "Failed for input: {:?}", input);
        }
    }

    #[test]
    fn test_decode_round_trip() {
        let dictionary = make_base64_dict();
        for len in 0..100 {
            let original: Vec<u8> = (0..len).map(|i| (i * 7) as u8).collect();
            let encoded = encode(&original, &dictionary, DictionaryVariant::Base64Standard)
                .expect("encode should succeed");
            let decoded = decode(&encoded, DictionaryVariant::Base64Standard)
                .expect("decode of freshly encoded data should succeed");
            assert_eq!(decoded, original, "Round-trip failed at length {}", len);
        }
    }

    #[test]
    fn test_avx2_large_input() {
        // 48 bytes is large enough to exercise the AVX2 encode path.
        let dictionary = make_base64_dict();
        let test_data: Vec<u8> = (0..48).collect();
        let simd_result = encode(&test_data, &dictionary, DictionaryVariant::Base64Standard)
            .expect("encode should succeed");
        let scalar_result =
            crate::encoders::algorithms::chunked::encode_chunked(&test_data, &dictionary);
        assert_eq!(
            simd_result, scalar_result,
            "AVX2 path should match scalar output"
        );
        let decoded = decode(&simd_result, DictionaryVariant::Base64Standard)
            .expect("decode of valid data should succeed");
        assert_eq!(decoded, test_data, "AVX2 round-trip failed");
    }

    #[test]
    fn test_decode_url_safe() {
        let test_cases = [
            ("AQID-A__", vec![1, 2, 3, 248, 15, 255]),
            ("SGVsbG8tV29ybGQ", b"Hello-World".to_vec()),
        ];
        for (input, expected) in test_cases {
            let decoded = decode(input, DictionaryVariant::Base64Url)
                .unwrap_or_else(|| panic!("Failed to decode URL-safe input: {}", input));
            assert_eq!(decoded, expected, "URL-safe decode failed for: {}", input);
        }
    }

    #[test]
    fn test_avx2_decode_large() {
        let dictionary = make_base64_dict();
        let test_data: Vec<u8> = (0..48).map(|i| (i * 3) as u8).collect();
        let encoded = encode(&test_data, &dictionary, DictionaryVariant::Base64Standard)
            .expect("encode should succeed");
        let decoded = decode(&encoded, DictionaryVariant::Base64Standard)
            .expect("decode of valid data should succeed");
        assert_eq!(decoded, test_data, "AVX2 decode failed");
    }
}