#![allow(unused_unsafe)]
use super::super::common;
use crate::core::dictionary::Dictionary;
use crate::simd::variants::DictionaryVariant;
#[cfg(target_arch = "aarch64")]
/// Base64-encodes `data` via the NEON-accelerated path.
///
/// `variant` selects the alphabet used by the SIMD translation step;
/// `dictionary` is forwarded to the scalar fallback. Always returns
/// `Some` with the encoded text.
pub fn encode(data: &[u8], dictionary: &Dictionary, variant: DictionaryVariant) -> Option<String> {
    // Every 3 input bytes become 4 output characters (rounded up to
    // account for the padded final quantum).
    let capacity = data.len().div_ceil(3) * 4;
    let mut encoded = String::with_capacity(capacity);
    // SAFETY: NEON is a mandatory feature on aarch64, so the
    // target-feature function is always callable here.
    unsafe { encode_neon_impl(data, dictionary, variant, &mut encoded) };
    Some(encoded)
}
#[cfg(target_arch = "aarch64")]
/// Decodes base64 text via the NEON-accelerated path.
///
/// Returns `None` when the input length is invalid (a remainder of one
/// character can never occur in valid base64) or when decoding hits a
/// character outside the selected alphabet.
pub fn decode(encoded: &str, variant: DictionaryVariant) -> Option<Vec<u8>> {
    let bytes = encoded.as_bytes();
    // Only characters before the '=' padding carry data.
    let significant = encoded.trim_end_matches('=').len();
    // Four characters decode to three bytes; a trailing quantum of two
    // or three characters yields one or two bytes respectively.
    let tail = match significant % 4 {
        0 => 0,
        2 => 1,
        3 => 2,
        _ => return None,
    };
    let mut decoded = Vec::with_capacity((significant / 4) * 3 + tail);
    // SAFETY: NEON is always available on aarch64.
    let ok = unsafe { decode_neon_impl(bytes, variant, &mut decoded) };
    ok.then_some(decoded)
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
/// NEON-accelerated base64 encode loop.
///
/// Consumes 12 input bytes per round (expanding them into one full
/// 16-character vector) and delegates short inputs and the tail to the
/// scalar encoder.
///
/// NOTE(review): the SIMD rounds translate purely via `variant`, while
/// `dictionary` is only consulted by the scalar remainder — confirm the
/// two always describe the same alphabet.
unsafe fn encode_neon_impl(
    data: &[u8],
    dictionary: &Dictionary,
    variant: DictionaryVariant,
    result: &mut String,
) {
    use std::arch::aarch64::*;
    // 12 input bytes -> 16 base64 characters per round.
    const BLOCK_SIZE: usize = 12;
    // Not even one 16-byte vector load fits: encode everything scalar.
    if data.len() < 16 {
        encode_scalar_remainder(data, dictionary, result);
        return;
    }
    // Each round loads 16 bytes but only consumes 12; excluding the last
    // 4 bytes from the SIMD range keeps every load inside the slice.
    let safe_len = if data.len() >= 4 { data.len() - 4 } else { 0 };
    let (num_rounds, simd_bytes) = common::calculate_blocks(safe_len, BLOCK_SIZE);
    let mut offset = 0;
    for _ in 0..num_rounds {
        // SAFETY: offset + 16 <= simd_bytes + 4 <= data.len(), so the
        // full 16-byte load stays in bounds.
        let input_vec = unsafe { vld1q_u8(data.as_ptr().add(offset)) };
        // Split the 3-byte groups into 6-bit indices, then map to ASCII.
        let reshuffled = unsafe { reshuffle(input_vec) };
        let encoded = unsafe { translate(reshuffled, variant) };
        let mut output_buf = [0u8; 16];
        // SAFETY: output_buf is exactly 16 bytes, matching the vector width.
        unsafe {
            vst1q_u8(output_buf.as_mut_ptr(), encoded);
        }
        // Translated bytes are ASCII alphabet characters, so the
        // byte-to-char conversion is 1:1.
        for &byte in &output_buf {
            result.push(byte as char);
        }
        offset += BLOCK_SIZE;
    }
    // Encode whatever the SIMD loop did not cover (at least 4 bytes,
    // since safe_len excluded the final 4).
    if simd_bytes < data.len() {
        encode_scalar_remainder(&data[simd_bytes..], dictionary, result);
    }
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
/// Rearranges 12 raw input bytes (four groups of three) into sixteen
/// 6-bit indices — one per output byte — ready for alphabet translation.
unsafe fn reshuffle(input: std::arch::aarch64::uint8x16_t) -> std::arch::aarch64::uint8x16_t {
    use std::arch::aarch64::*;
    // Duplicate source bytes so each 32-bit lane holds one 3-byte group
    // laid out as (b1, b0, b2, b1), positioning all four 6-bit fields
    // for the bit extractions below.
    let shuffle_indices = unsafe {
        vld1q_u8(
            [
                1, 0, 2, 1, 4, 3, 5, 4, 7, 6, 8, 7, 10, 9, 11, 10,
            ]
            .as_ptr(),
        )
    };
    let shuffled = unsafe { vqtbl1q_u8(input, shuffle_indices) };
    let shuffled_u32 = unsafe { vreinterpretq_u32_u8(shuffled) };
    // Isolate two of the four 6-bit fields of each group...
    let t0 = unsafe { vandq_u32(shuffled_u32, vdupq_n_u32(0x0FC0FC00)) };
    // ...and shift them into their output byte positions. NEON has no
    // per-u16 multiply-high, so it is emulated with a widening multiply
    // (by 0x0400 / 0x0040 per half) followed by a narrowing shift of 16.
    let t1 = unsafe {
        let t0_u16 = vreinterpretq_u16_u32(t0);
        let mult_pattern = vreinterpretq_u16_u32(vdupq_n_u32(0x04000040));
        let lo = vget_low_u16(t0_u16);
        let hi = vget_high_u16(t0_u16);
        let mult_lo = vget_low_u16(mult_pattern);
        let mult_hi = vget_high_u16(mult_pattern);
        let lo_32 = vmull_u16(lo, mult_lo);
        let hi_32 = vmull_u16(hi, mult_hi);
        let lo_result = vshrn_n_u32(lo_32, 16);
        let hi_result = vshrn_n_u32(hi_32, 16);
        vreinterpretq_u32_u16(vcombine_u16(lo_result, hi_result))
    };
    // The remaining two 6-bit fields, moved left with a plain per-u16
    // multiply (by 0x0100 / 0x0010).
    let t2 = unsafe { vandq_u32(shuffled_u32, vdupq_n_u32(0x003F03F0)) };
    let t3 = unsafe {
        let t2_u16 = vreinterpretq_u16_u32(t2);
        let mult_pattern = vreinterpretq_u16_u32(vdupq_n_u32(0x01000010));
        vreinterpretq_u32_u16(vmulq_u16(t2_u16, mult_pattern))
    };
    // The two halves occupy disjoint bits, so OR combines them.
    unsafe { vreinterpretq_u8_u32(vorrq_u32(t1, t3)) }
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
/// Maps 6-bit indices (0..=63) to their ASCII alphabet bytes by adding a
/// per-range offset looked up from a 16-entry table.
///
/// Offsets, read as wrapping i8: 65 ('A') for 0..=25; 71 ('a' - 26) for
/// 26..=51; 252 (-4) for digits 52..=61; then the two variant-specific
/// symbols — 237/-19 ('+') and 240/-16 ('/') for standard, 239/-17 ('-')
/// and 32 ('_') for URL-safe.
unsafe fn translate(
    indices: std::arch::aarch64::uint8x16_t,
    variant: DictionaryVariant,
) -> std::arch::aarch64::uint8x16_t {
    use std::arch::aarch64::*;
    let lut = match variant {
        DictionaryVariant::Base64Standard => unsafe {
            vld1q_u8(
                [
                    65, 71, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252,
                    237, 240, 0, 0,
                ]
                .as_ptr(),
            )
        },
        DictionaryVariant::Base64Url => unsafe {
            vld1q_u8(
                [
                    65, 71, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252,
                    239, 32, 0, 0,
                ]
                .as_ptr(),
            )
        },
    };
    // Saturating subtract: 0 for indices <= 51, index - 51 above that.
    let mut lut_indices = unsafe { vqsubq_u8(indices, vdupq_n_u8(51)) };
    // vcgtq yields 0xFF for indices > 25; subtracting it adds 1, so the
    // final LUT slot is 0 (A-Z), 1 (a-z), 2..=11 (digits), 12 or 13
    // (the last two symbols).
    let indices_s8 = unsafe { vreinterpretq_s8_u8(indices) };
    let mask = unsafe { vcgtq_s8(indices_s8, vdupq_n_s8(25)) };
    lut_indices = unsafe { vsubq_u8(lut_indices, mask) };
    let offsets = unsafe { vqtbl1q_u8(lut, lut_indices) };
    // Wrapping add of the per-range offset produces the ASCII byte.
    unsafe { vaddq_u8(indices, offsets) }
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
/// NEON-accelerated base64 decode loop.
///
/// Strips trailing '=' padding, decodes 16 input characters into 12
/// output bytes per round, and finishes the tail (and any input shorter
/// than one block) with the scalar decoder. Returns `false` as soon as
/// an invalid character is encountered.
///
/// NOTE(review): the validation tables from `get_decode_luts` are shared
/// by both variants and appear to reject '-'/'_'; Base64Url inputs of
/// 16+ significant characters may therefore fail on the SIMD path even
/// when valid — verify with a long URL-safe round-trip.
unsafe fn decode_neon_impl(
    encoded: &[u8],
    variant: DictionaryVariant,
    result: &mut Vec<u8>,
) -> bool {
    use std::arch::aarch64::*;
    // 16 base64 characters decode to 12 raw bytes per round.
    const INPUT_BLOCK_SIZE: usize = 16;
    const OUTPUT_BLOCK_SIZE: usize = 12;
    // Drop trailing '=' so only significant characters are decoded. An
    // all-'=' input falls through unchanged; the scalar closure below
    // maps '=' to None, so such input is presumably rejected there.
    let input_no_padding = if let Some(last_non_pad) = encoded.iter().rposition(|&b| b != b'=') {
        &encoded[..=last_non_pad]
    } else {
        encoded
    };
    let (lut_lo, lut_hi, lut_roll) = unsafe { get_decode_luts(variant) };
    let (num_rounds, simd_bytes) =
        common::calculate_blocks(input_no_padding.len(), INPUT_BLOCK_SIZE);
    for round in 0..num_rounds {
        let offset = round * INPUT_BLOCK_SIZE;
        // SAFETY: offset + 16 <= simd_bytes <= input_no_padding.len(),
        // so the full 16-byte load is in bounds.
        let input_vec = unsafe { vld1q_u8(input_no_padding.as_ptr().add(offset)) };
        // Reject the input if any byte in this block is outside the alphabet.
        if !unsafe { validate(input_vec, lut_lo, lut_hi) } {
            return false;
        }
        // ASCII -> 6-bit values, then pack four 6-bit values per 3 bytes.
        let decoded = unsafe {
            let indices = translate_decode(input_vec, lut_hi, lut_roll);
            reshuffle_decode(indices)
        };
        let mut output_buf = [0u8; 16];
        // SAFETY: output_buf is exactly 16 bytes, matching the vector width.
        unsafe {
            vst1q_u8(output_buf.as_mut_ptr(), decoded);
        }
        // Only the first 12 bytes carry data; the rest is filler.
        result.extend_from_slice(&output_buf[0..OUTPUT_BLOCK_SIZE]);
    }
    // Decode the remaining (< 16) characters via a scalar table lookup
    // covering both alphabets' variant-specific symbols.
    if simd_bytes < input_no_padding.len() {
        let remainder = &input_no_padding[simd_bytes..];
        if !decode_scalar_remainder(
            remainder,
            &mut |c| match c {
                b'A'..=b'Z' => Some(c - b'A'),
                b'a'..=b'z' => Some(c - b'a' + 26),
                b'0'..=b'9' => Some(c - b'0' + 52),
                b'+' if matches!(variant, DictionaryVariant::Base64Standard) => Some(62),
                b'/' if matches!(variant, DictionaryVariant::Base64Standard) => Some(63),
                b'-' if matches!(variant, DictionaryVariant::Base64Url) => Some(62),
                b'_' if matches!(variant, DictionaryVariant::Base64Url) => Some(63),
                _ => None,
            },
            result,
        ) {
            return false;
        }
    }
    true
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
/// Builds the three lookup tables used by the SIMD decoder.
///
/// `lut_lo` (indexed by low nibble) and `lut_hi` (indexed by high nibble)
/// give bitsets whose AND is non-zero exactly for bytes `validate` must
/// reject. `lut_roll` holds the per-range offsets `translate_decode` adds
/// to an ASCII byte to recover its 6-bit value.
///
/// NOTE(review): `lut_lo`/`lut_hi` do not vary with `variant`, yet the
/// alphabets differ in their last two symbols; '-' (0x2D) and '_' (0x5F)
/// appear to be flagged invalid by these tables, which would break the
/// SIMD path for long Base64Url inputs — verify.
unsafe fn get_decode_luts(
    variant: DictionaryVariant,
) -> (
    std::arch::aarch64::uint8x16_t,
    std::arch::aarch64::uint8x16_t,
    std::arch::aarch64::uint8x16_t,
) {
    use std::arch::aarch64::*;
    let lut_lo = unsafe {
        vld1q_u8(
            [
                0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x13, 0x1A, 0x1B, 0x1B,
                0x1B, 0x1A,
            ]
            .as_ptr(),
        )
    };
    let lut_hi = unsafe {
        vld1q_u8(
            [
                0x10, 0x10, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
                0x10, 0x10,
            ]
            .as_ptr(),
        )
    };
    // Roll offsets as wrapping i8: 191 = -65 maps 'A'..='Z' to 0..,
    // 185 = -71 maps 'a'..='z' to 26.., 4 maps '0'..='9' to 52..; slots
    // 1 and 2 hold the variant-specific symbols.
    let lut_roll = match variant {
        DictionaryVariant::Base64Standard => unsafe {
            // Slot 1: 16 ('/' 0x2F + 16 = 63); slot 2: 19 ('+' 0x2B + 19 = 62).
            vld1q_u8(
                [
                    0, 16, 19, 4, 191, 191, 185, 185, 0, 0, 0, 0, 0, 0, 0,
                    0,
                ]
                .as_ptr(),
            )
        },
        DictionaryVariant::Base64Url => unsafe {
            // Slot 1: 17 ('-' 0x2D + 17 = 62); slot 2: 224 = -32
            // ('_' 0x5F - 32 = 63).
            // NOTE(review): translate_decode only decrements the roll
            // index for bytes equal to 0x2F, so confirm these slots are
            // actually reachable for '-'/'_' input bytes.
            vld1q_u8(
                [
                    0, 17, 224, 4, 191, 191, 185, 185, 0, 0, 0, 0, 0, 0, 0,
                    0,
                ]
                .as_ptr(),
            )
        },
    };
    (lut_lo, lut_hi, lut_roll)
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
/// Validates all 16 input bytes at once: each byte's low and high nibble
/// select a bitset from the two LUTs, and a non-zero AND of the two
/// bitsets marks a byte outside the accepted alphabet.
///
/// Returns `true` only when every lane produced 0 (all 16 bytes valid).
unsafe fn validate(
    input: std::arch::aarch64::uint8x16_t,
    lut_lo: std::arch::aarch64::uint8x16_t,
    lut_hi: std::arch::aarch64::uint8x16_t,
) -> bool {
    use std::arch::aarch64::*;
    let lo_nibbles = unsafe { vandq_u8(input, vdupq_n_u8(0x0F)) };
    let hi_nibbles_shifted = unsafe { vshrq_n_u8(input, 4) };
    let hi_nibbles = unsafe { vandq_u8(hi_nibbles_shifted, vdupq_n_u8(0x0F)) };
    let lo_lookup = unsafe { vqtbl1q_u8(lut_lo, lo_nibbles) };
    let hi_lookup = unsafe { vqtbl1q_u8(lut_hi, hi_nibbles) };
    let validation = unsafe { vandq_u8(lo_lookup, hi_lookup) };
    // Horizontal max is 0 iff no lane flagged an invalid byte.
    unsafe { vmaxvq_u8(validation) == 0 }
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
/// Converts ASCII base64 bytes to 6-bit values by adding a per-range
/// offset from `lut_roll`, indexed by each byte's high nibble — minus one
/// for bytes equal to '/' (0x2F), which shares a high nibble with '+'.
///
/// NOTE(review): the special-case compare is against 0x2F for both
/// variants; Base64Url symbols '-' (0x2D) and '_' (0x5F) get no such
/// adjustment, so their dedicated roll entries look unreachable — verify
/// the URL SIMD decode path end to end. `_lut_hi` is accepted but unused.
unsafe fn translate_decode(
    input: std::arch::aarch64::uint8x16_t,
    _lut_hi: std::arch::aarch64::uint8x16_t,
    lut_roll: std::arch::aarch64::uint8x16_t,
) -> std::arch::aarch64::uint8x16_t {
    use std::arch::aarch64::*;
    let hi_nibbles_shifted = unsafe { vshrq_n_u8(input, 4) };
    let hi_nibbles = unsafe { vandq_u8(hi_nibbles_shifted, vdupq_n_u8(0x0F)) };
    // 0xFF (= -1) where the byte is '/'; adding it decrements the index.
    let eq_2f = unsafe { vceqq_u8(input, vdupq_n_u8(0x2F)) };
    let roll_index = unsafe { vaddq_u8(eq_2f, hi_nibbles) };
    let offsets = unsafe { vqtbl1q_u8(lut_roll, roll_index) };
    // Wrapping add recovers the 6-bit value for each lane.
    unsafe { vaddq_u8(input, offsets) }
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
/// Packs sixteen decoded 6-bit values (one per input byte) back into
/// twelve output bytes; the last four lanes of the result are zero filler
/// that the caller discards.
///
/// Fix: removed the dead `_indices_u32` reinterpret local and wrapped the
/// remaining intrinsic calls in `unsafe` blocks for consistency with the
/// rest of the file (`#![allow(unused_unsafe)]` keeps this warning-free).
unsafe fn reshuffle_decode(
    indices: std::arch::aarch64::uint8x16_t,
) -> std::arch::aarch64::uint8x16_t {
    use std::arch::aarch64::*;
    // Spread even/odd source bytes into separate u16 lanes (index 255
    // lanes read as zero in vqtbl1q).
    let even_mask = unsafe {
        vld1q_u8(
            [
                0, 255, 2, 255, 4, 255, 6, 255, 8, 255, 10, 255, 12, 255, 14, 255,
            ]
            .as_ptr(),
        )
    };
    let odd_mask = unsafe {
        vld1q_u8(
            [
                1, 255, 3, 255, 5, 255, 7, 255, 9, 255, 11, 255, 13, 255, 15, 255,
            ]
            .as_ptr(),
        )
    };
    let even_bytes = unsafe { vqtbl1q_u8(indices, even_mask) };
    let odd_bytes = unsafe { vqtbl1q_u8(indices, odd_mask) };
    // Merge each (even, odd) pair of 6-bit values into a 12-bit value
    // per u16 lane: even << 6 | odd.
    let even_u16 = unsafe { vreinterpretq_u16_u8(even_bytes) };
    let odd_u16 = unsafe { vreinterpretq_u16_u8(odd_bytes) };
    let merged_u16 = unsafe { vaddq_u16(vshlq_n_u16(even_u16, 6), odd_u16) };
    // Gather the low and high 12-bit halves of each 24-bit group into
    // separate u32 lanes.
    let lo_pair_mask = unsafe {
        vld1q_u8(
            [
                0, 1, 255, 255, 4, 5, 255, 255, 8, 9, 255, 255, 12, 13, 255, 255,
            ]
            .as_ptr(),
        )
    };
    let hi_pair_mask = unsafe {
        vld1q_u8(
            [
                2, 3, 255, 255, 6, 7, 255, 255, 10, 11, 255, 255, 14, 15, 255, 255,
            ]
            .as_ptr(),
        )
    };
    let lo_pairs = unsafe {
        vreinterpretq_u32_u8(vqtbl1q_u8(vreinterpretq_u8_u16(merged_u16), lo_pair_mask))
    };
    let hi_pairs = unsafe {
        vreinterpretq_u32_u8(vqtbl1q_u8(vreinterpretq_u8_u16(merged_u16), hi_pair_mask))
    };
    // Combine into one 24-bit value per u32 lane: lo << 12 | hi.
    let final_u32 = unsafe { vaddq_u32(vshlq_n_u32(lo_pairs, 12), hi_pairs) };
    // Reorder each lane's bytes to output order and compact the four
    // 3-byte groups into the first 12 lanes (255 -> zero fill).
    let shuffle =
        unsafe { vld1q_u8([2, 1, 0, 6, 5, 4, 10, 9, 8, 14, 13, 12, 255, 255, 255, 255].as_ptr()) };
    unsafe { vqtbl1q_u8(vreinterpretq_u8_u32(final_u32), shuffle) }
}
/// Scalar fallback for inputs shorter than one SIMD load and for the
/// bytes left after the last full SIMD block; delegates to the shared
/// chunked encoder (presumably also responsible for '=' padding — see
/// `common::encode_scalar_chunked`).
fn encode_scalar_remainder(data: &[u8], dictionary: &Dictionary, result: &mut String) {
    common::encode_scalar_chunked(data, dictionary, result);
}
/// Scalar fallback for the characters left after the last full SIMD
/// block; `char_to_index` maps one alphabet byte to its 6-bit value
/// (returning `None` for invalid bytes). Returns `false` when decoding
/// fails, which the caller propagates. The trailing `6` is presumably
/// the bits-per-symbol argument — confirm against
/// `common::decode_scalar_chunked`.
fn decode_scalar_remainder(
    data: &[u8],
    char_to_index: &mut dyn FnMut(u8) -> Option<u8>,
    result: &mut Vec<u8>,
) -> bool {
    common::decode_scalar_chunked(data, char_to_index, result, 6)
}
#[cfg(test)]
#[allow(deprecated)]
mod tests {
    use super::*;
    use crate::core::config::EncodingMode;
    use crate::core::dictionary::Dictionary;

    /// Standard base64 alphabet (matches RFC 4648 §4) with '=' padding.
    fn make_base64_dict() -> Dictionary {
        let chars: Vec<char> = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
            .chars()
            .collect();
        Dictionary::new_with_mode(chars, EncodingMode::Chunked, Some('=')).unwrap()
    }

    /// URL-safe base64 alphabet (matches RFC 4648 §5) with '=' padding.
    fn make_base64_url_dict() -> Dictionary {
        let chars: Vec<char> = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
            .chars()
            .collect();
        Dictionary::new_with_mode(chars, EncodingMode::Chunked, Some('=')).unwrap()
    }

    // 13-byte input is below the 16-byte SIMD threshold, so this covers
    // the scalar fallback inside encode_neon_impl.
    #[test]
    fn test_encode_basic() {
        let dict = make_base64_dict();
        let input = b"Hello, World!";
        let result = encode(input, &dict, DictionaryVariant::Base64Standard);
        assert!(result.is_some());
        let encoded = result.unwrap();
        assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ==");
    }

    // Input chosen so the output uses both URL-specific symbols ('-', '_').
    // 3 bytes -> scalar path only.
    #[test]
    fn test_encode_url_safe() {
        let dict = make_base64_url_dict();
        let input = b"\xfb\xff\xfe";
        let result = encode(input, &dict, DictionaryVariant::Base64Url);
        assert!(result.is_some());
        let encoded = result.unwrap();
        assert_eq!(encoded, "-__-");
    }

    #[test]
    fn test_decode_basic() {
        let encoded = "SGVsbG8sIFdvcmxkIQ==";
        let result = decode(encoded, DictionaryVariant::Base64Standard);
        assert!(result.is_some());
        let decoded = result.unwrap();
        assert_eq!(decoded, b"Hello, World!");
    }

    // 4 significant chars -> scalar decode path only; a 16+ char URL-safe
    // input (exercising the SIMD decode rounds) is not covered here.
    #[test]
    fn test_decode_url_safe() {
        let encoded = "-__-";
        let result = decode(encoded, DictionaryVariant::Base64Url);
        assert!(result.is_some());
        let decoded = result.unwrap();
        assert_eq!(decoded, b"\xfb\xff\xfe");
    }

    // Covers all padding cases (len % 3 == 0, 1, 2) plus a 43-byte input
    // long enough to enter the SIMD encode/decode rounds.
    #[test]
    fn test_round_trip() {
        let dict = make_base64_dict();
        let inputs: Vec<&[u8]> = vec![
            b"",
            b"f",
            b"fo",
            b"foo",
            b"foob",
            b"fooba",
            b"foobar",
            b"The quick brown fox jumps over the lazy dog",
        ];
        for input in inputs {
            let encoded =
                encode(input, &dict, DictionaryVariant::Base64Standard).expect("encode failed");
            let decoded =
                decode(&encoded, DictionaryVariant::Base64Standard).expect("decode failed");
            assert_eq!(decoded, input, "round-trip failed for input: {:?}", input);
        }
    }

    #[test]
    fn test_invalid_decode() {
        // '!' is outside the alphabet, so decoding must fail.
        let result = decode("SGVs!G8=", DictionaryVariant::Base64Standard);
        assert!(result.is_none());
        // NOTE(review): "SGVs" (an unpadded but complete 4-char quantum)
        // is deliberately left unasserted; presumably it decodes to
        // b"Hel" — confirm the unpadded-input policy and tighten this.
        let result = decode("SGVs", DictionaryVariant::Base64Standard);
        let _ = result;
    }

    // 1 KiB input: many full 12-byte SIMD encode blocks and 16-char
    // decode blocks, exercising the vectorized loops end to end.
    #[test]
    fn test_large_input() {
        let dict = make_base64_dict();
        let input = vec![42u8; 1024];
        let encoded =
            encode(&input, &dict, DictionaryVariant::Base64Standard).expect("encode failed");
        let decoded = decode(&encoded, DictionaryVariant::Base64Standard).expect("decode failed");
        assert_eq!(decoded, input);
    }
}