use crate::error::{Result, ZiporaError};
use base64::engine::general_purpose;
use base64::Engine;
#[derive(Debug, Clone)]
pub struct Base64Config {
pub url_safe: bool,
pub padding: bool,
pub force_implementation: Option<SimdImplementation>,
}
impl Default for Base64Config {
fn default() -> Self {
Self {
url_safe: false,
padding: true,
force_implementation: None,
}
}
}
/// Identifies a Base64 backend by the instruction set it targets.
///
/// Deriving `Hash` (alongside `Eq`) allows the tag to be used as a key in
/// `HashMap`/`HashSet`, e.g. for per-backend statistics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SimdImplementation {
    /// Portable, non-SIMD backend.
    Scalar,
    /// Backend targeting SSE4.2.
    SSE42,
    /// Backend targeting AVX2.
    AVX2,
    /// Backend targeting AVX-512.
    AVX512,
    /// Backend targeting ARM NEON.
    NEON,
}
/// Base64 codec whose alphabet and padding are chosen by a [`Base64Config`].
///
/// Encoding and decoding go through the `base64` crate's general-purpose
/// engines. `Debug` and `Clone` are derived because the wrapped config
/// already implements both, so the codec can be logged and duplicated freely.
#[derive(Debug, Clone)]
pub struct AdaptiveBase64 {
    /// Settings consulted on every encode/decode call.
    config: Base64Config,
}
impl AdaptiveBase64 {
pub fn new() -> Self {
Self::with_config(Base64Config::default())
}
pub fn with_config(config: Base64Config) -> Self {
Self { config }
}
fn engine(&self) -> impl Engine + '_ {
use base64::engine::general_purpose::*;
match (self.config.url_safe, self.config.padding) {
(false, true) => STANDARD,
(false, false) => STANDARD_NO_PAD,
(true, true) => URL_SAFE,
(true, false) => URL_SAFE_NO_PAD,
}
}
pub fn encode(&self, input: &[u8]) -> String {
self.engine().encode(input)
}
pub fn decode(&self, input: &str) -> Result<Vec<u8>> {
self.engine()
.decode(input)
.map_err(|e| ZiporaError::invalid_data(format!("base64 decode error: {}", e)))
}
pub fn selected_implementation(&self) -> SimdImplementation {
SimdImplementation::Scalar
}
}
impl Default for AdaptiveBase64 {
fn default() -> Self {
Self::new()
}
}
/// Encode-only front end over an [`AdaptiveBase64`] codec.
pub struct SimdBase64Encoder {
    /// Codec that performs the actual encoding.
    codec: AdaptiveBase64,
}
impl SimdBase64Encoder {
pub fn new() -> Self {
Self { codec: AdaptiveBase64::new() }
}
pub fn with_config(config: Base64Config) -> Self {
Self { codec: AdaptiveBase64::with_config(config) }
}
pub fn encode(&self, input: &[u8]) -> String {
self.codec.encode(input)
}
}
impl Default for SimdBase64Encoder {
fn default() -> Self { Self::new() }
}
/// Decode-only front end over an [`AdaptiveBase64`] codec.
pub struct SimdBase64Decoder {
    /// Codec that performs the actual decoding.
    codec: AdaptiveBase64,
}
impl SimdBase64Decoder {
pub fn new() -> Self {
Self { codec: AdaptiveBase64::new() }
}
pub fn with_config(config: Base64Config) -> Self {
Self { codec: AdaptiveBase64::with_config(config) }
}
pub fn decode(&self, input: &str) -> Result<Vec<u8>> {
self.codec.decode(input)
}
}
impl Default for SimdBase64Decoder {
fn default() -> Self { Self::new() }
}
/// Encodes `input` with the standard, padded Base64 alphabet.
///
/// Despite the `_simd` suffix, this calls the `base64` crate's
/// `general_purpose::STANDARD` engine directly. The result matches
/// `AdaptiveBase64::new().encode(input)`.
pub fn base64_encode_simd(input: &[u8]) -> String {
    let engine = &general_purpose::STANDARD;
    engine.encode(input)
}
/// Decodes standard-alphabet Base64 text using `general_purpose::STANDARD`.
///
/// # Errors
///
/// Returns `ZiporaError::invalid_data` when the engine rejects `input`.
/// The message is prefixed with `"base64 decode error: "`.
pub fn base64_decode_simd(input: &str) -> Result<Vec<u8>> {
    let outcome = general_purpose::STANDARD.decode(input);
    outcome.map_err(|cause| {
        let message = format!("base64 decode error: {}", cause);
        ZiporaError::invalid_data(message)
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_decode_roundtrip() {
        let codec = AdaptiveBase64::new();
        let data = b"Hello, World!";
        let encoded = codec.encode(data);
        assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ==");
        let decoded = codec.decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_url_safe() {
        let config = Base64Config {
            url_safe: true,
            padding: true,
            force_implementation: None,
        };
        let codec = AdaptiveBase64::with_config(config);
        // 0xfb 0xff 0xfe splits into the sextets 62, 63, 63, 62. These map to
        // '+' and '/' in the standard alphabet and to '-' and '_' in the URL-safe one.
        let data = b"\xfb\xff\xfe";
        let encoded = codec.encode(data);
        assert_eq!(encoded, "-__-");
        assert!(!encoded.contains('+'));
        assert!(!encoded.contains('/'));
        assert_eq!(codec.decode(&encoded).unwrap(), data);
        // The default (standard) codec must use the other alphabet for the same bytes.
        assert_eq!(AdaptiveBase64::new().encode(data), "+//+");
    }

    #[test]
    fn test_no_padding() {
        let config = Base64Config {
            url_safe: false,
            padding: false,
            force_implementation: None,
        };
        let codec = AdaptiveBase64::with_config(config);
        let encoded = codec.encode(b"f");
        // The padded form would be "Zg=="; without padding only "Zg" remains.
        assert_eq!(encoded, "Zg");
        assert!(!encoded.contains('='));
        assert_eq!(codec.decode(&encoded).unwrap(), b"f");
    }

    #[test]
    fn test_default_config_and_implementation() {
        let config = Base64Config::default();
        assert!(!config.url_safe);
        assert!(config.padding);
        assert_eq!(config.force_implementation, None);
        assert_eq!(
            AdaptiveBase64::default().selected_implementation(),
            SimdImplementation::Scalar
        );
        // Forcing a backend does not change the reported implementation.
        let forced = AdaptiveBase64::with_config(Base64Config {
            force_implementation: Some(SimdImplementation::AVX2),
            ..Base64Config::default()
        });
        assert_eq!(forced.selected_implementation(), SimdImplementation::Scalar);
    }

    #[test]
    fn test_convenience_functions() {
        let input = b"foobar";
        let encoded = base64_encode_simd(input);
        assert_eq!(encoded, "Zm9vYmFy");
        let decoded = base64_decode_simd(&encoded).unwrap();
        assert_eq!(decoded, input);
    }

    #[test]
    fn test_encoder_decoder_types() {
        let encoder = SimdBase64Encoder::new();
        let decoder = SimdBase64Decoder::new();
        let data = b"test data 12345";
        let encoded = encoder.encode(data);
        let decoded = decoder.decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_empty_input() {
        let codec = AdaptiveBase64::new();
        assert_eq!(codec.encode(b""), "");
        assert_eq!(codec.decode("").unwrap(), Vec::<u8>::new());
    }

    #[test]
    fn test_invalid_input() {
        let codec = AdaptiveBase64::new();
        assert!(codec.decode("!!!invalid!!!").is_err());
    }
}