use std::{fmt, marker::PhantomData};
use base64::{
Engine,
engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig, general_purpose},
};
use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
use zeroize::Zeroize;
/// A byte buffer of type `B` that is serialized as a base64 string.
///
/// Encoding never emits padding. Decoding accepts input with or without padding and
/// tolerates non-zero trailing bits. The alphabet is chosen by `C`, a [`Base64Config`]
/// marker such as [`Standard`] (the default) or [`UrlSafe`]. `B` defaults to `Vec<u8>`;
/// fixed-size arrays can be used to require an exact decoded length.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Base64<C = Standard, B = Vec<u8>> {
/// The wrapped, already-decoded bytes.
bytes: B,
// `fn(C) -> C` records the alphabet in the type without storing a `C`. Function pointers
// are always `Send + Sync`, so `C` does not affect those auto traits; `C` is invariant.
_phantom_conf: PhantomData<fn(C) -> C>,
}
impl<C, B> Zeroize for Base64<C, B>
where
B: Zeroize,
{
fn zeroize(&mut self) {
self.bytes.zeroize();
}
}
/// Selects the base64 alphabet used by [`Base64`].
///
/// This module provides the marker types [`Standard`] and [`UrlSafe`].
pub trait Base64Config {
    #[doc(hidden)]
    const CONF: Conf;
}

/// Internal carrier for the alphabet of a [`Base64Config`]; not part of the public API.
#[doc(hidden)]
pub struct Conf(base64::alphabet::Alphabet);

/// Marker for the standard base64 alphabet (RFC 4648 §4, symbols `+` and `/`).
///
/// Being a zero-sized marker, it is `Copy` and also implements `Debug` and `Hash` so that
/// it satisfies the bounds generic code commonly places on type parameters.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Standard;

impl Base64Config for Standard {
    const CONF: Conf = Conf(base64::alphabet::STANDARD);
}

/// Marker for the URL- and filename-safe base64 alphabet (RFC 4648 §5, symbols `-` and `_`).
///
/// Being a zero-sized marker, it is `Copy` and also implements `Debug` and `Hash` so that
/// it satisfies the bounds generic code commonly places on type parameters.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct UrlSafe;

impl Base64Config for UrlSafe {
    const CONF: Conf = Conf(base64::alphabet::URL_SAFE);
}
impl<C: Base64Config, B> Base64<C, B> {
    /// Engine settings shared by every alphabet: encode without padding, decode leniently.
    const CONFIG: GeneralPurposeConfig = general_purpose::NO_PAD
        // Accept input whether or not it carries `=` padding.
        .with_decode_padding_mode(DecodePaddingMode::Indifferent)
        // Do not reject input whose final symbol has non-zero unused bits.
        .with_decode_allow_trailing_bits(true);

    /// The base64 engine for the alphabet selected by `C`.
    const ENGINE: GeneralPurpose = GeneralPurpose::new(&C::CONF.0, Self::CONFIG);
}
impl<C: Base64Config, B: AsRef<[u8]>> Base64<C, B> {
    /// Wraps already-decoded `bytes`.
    pub fn new(bytes: B) -> Self {
        Self { _phantom_conf: PhantomData, bytes }
    }

    /// Borrows the wrapped bytes as a slice.
    pub fn as_bytes(&self) -> &[u8] {
        <B as AsRef<[u8]>>::as_ref(&self.bytes)
    }

    /// Encodes the wrapped bytes as unpadded base64 using the alphabet selected by `C`.
    pub fn encode(&self) -> String {
        let raw = self.as_bytes();
        Self::ENGINE.encode(raw)
    }
}
impl<C, B> Base64<C, B> {
    /// Borrows the wrapped buffer as its original type.
    pub fn as_inner(&self) -> &B {
        let Self { bytes, .. } = self;
        bytes
    }

    /// Consumes the wrapper and hands back the wrapped buffer.
    pub fn into_inner(self) -> B {
        let Self { bytes, .. } = self;
        bytes
    }
}
impl<C: Base64Config> Base64<C> {
    /// Creates a wrapper around an empty byte vector; it encodes to the empty string.
    pub fn empty() -> Self {
        Self::new(vec![])
    }
}
impl<C: Base64Config, B: TryFromBase64DecodedBytes> Base64<C, B> {
    /// Decodes `encoded` with the alphabet selected by `C` and wraps the result in `B`.
    ///
    /// Padding is optional and non-zero trailing bits are tolerated.
    ///
    /// # Errors
    ///
    /// Returns an error if the input is not valid base64, or if the decoded bytes cannot be
    /// converted into `B` (for example a length mismatch for fixed-size arrays).
    pub fn parse(encoded: impl AsRef<[u8]>) -> Result<Self, Base64DecodeError> {
        Self::ENGINE
            .decode(encoded)
            .map_err(Base64DecodeError::base64)
            .and_then(B::try_from_bytes)
            .map(Self::new)
    }
}
/// Shows the encoded string, quoted as `String`'s `Debug` output.
impl<C: Base64Config, B: AsRef<[u8]>> fmt::Debug for Base64<C, B> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = self.encode();
        fmt::Debug::fmt(&encoded, f)
    }
}
/// Writes the encoded string via `String`'s `Display`, so width/fill/alignment flags apply.
impl<C: Base64Config, B: AsRef<[u8]>> fmt::Display for Base64<C, B> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = self.encode();
        fmt::Display::fmt(&encoded, f)
    }
}
impl<'de, C: Base64Config, B: TryFromBase64DecodedBytes> Deserialize<'de> for Base64<C, B> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let encoded = super::deserialize_cow_str(deserializer)?;
Self::parse(&*encoded).map_err(de::Error::custom)
}
}
/// Serializes as the unpadded base64 string produced by [`Base64::encode`].
impl<C: Base64Config, B: AsRef<[u8]>> Serialize for Base64<C, B> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let encoded = self.encode();
        serializer.serialize_str(encoded.as_str())
    }
}
/// Conversion from the `Vec<u8>` produced by base64 decoding into a buffer type.
///
/// [`Base64::parse`] uses this to support both growable (`Vec<u8>`) and fixed-size
/// (`[u8; N]`) buffers.
pub trait TryFromBase64DecodedBytes: Sized + AsRef<[u8]> {
    #[doc(hidden)]
    fn try_from_bytes(bytes: Vec<u8>) -> Result<Self, Base64DecodeError>;
}

/// A `Vec<u8>` accepts decoded output of any length.
impl TryFromBase64DecodedBytes for Vec<u8> {
    fn try_from_bytes(bytes: Vec<u8>) -> Result<Self, Base64DecodeError> {
        Ok(bytes)
    }
}

/// A fixed-size array requires the decoded output to be exactly `N` bytes long.
impl<const N: usize> TryFromBase64DecodedBytes for [u8; N] {
    fn try_from_bytes(bytes: Vec<u8>) -> Result<Self, Base64DecodeError> {
        let decoded_len = bytes.len();
        <[u8; N]>::try_from(bytes)
            .map_err(|_rejected| Base64DecodeError::invalid_decoded_length(decoded_len, N))
    }
}
#[derive(Clone)]
pub struct Base64DecodeError(Base64DecodeErrorInner);
impl Base64DecodeError {
fn base64(error: base64::DecodeError) -> Self {
Self(Base64DecodeErrorInner::Base64(error))
}
fn invalid_decoded_length(len: usize, expected: usize) -> Self {
Self(Base64DecodeErrorInner::InvalidDecodedLength { len, expected })
}
}
impl fmt::Debug for Base64DecodeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl fmt::Display for Base64DecodeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.0 {
Base64DecodeErrorInner::Base64(error) => write!(f, "invalid base64 encoding: {error}"),
Base64DecodeErrorInner::InvalidDecodedLength { len, expected } => {
write!(f, "invalid decoded base64 bytes length: {len}, expected {expected}")
}
}
}
}
impl std::error::Error for Base64DecodeError {}
/// Private representation of [`Base64DecodeError`], keeping the public error type opaque.
#[derive(Debug, Clone)]
enum Base64DecodeErrorInner {
/// The `base64` engine rejected the input.
Base64(base64::DecodeError),
/// The input decoded, but to a byte count the target buffer type does not accept.
InvalidDecodedLength {
/// Number of bytes actually decoded.
len: usize,
/// Number of bytes the target buffer type requires.
expected: usize,
},
}
#[cfg(test)]
mod tests {
    use super::{Base64, Standard, UrlSafe};

    /// `[0xfb, 0xff]` encodes to symbols 62, 63, 60, exercising both alphabet-specific symbols.
    const BYTES: [u8; 2] = [0xfb, 0xff];

    #[test]
    fn parse_base64() {
        const INPUT: &str = "3UmJnEIzUr2xWyaUnJg5fXwRybwG5FVC6Gq\
                             MHverEUn0ztuIsvVxX89JXX2pvdTsOBbLQx+4TVL02l4Cp5wPCm";
        const INPUT_WITH_PADDING: &str = "im9+knCkMNQNh9o6sbdcZw==";
        Base64::<Standard>::parse(INPUT).unwrap();
        Base64::<Standard>::parse(INPUT_WITH_PADDING)
            .expect("We should be able to decode padded Base64");
        Base64::<Standard, [u8; 32]>::parse(INPUT).unwrap_err();
        Base64::<Standard, [u8; 64]>::parse(INPUT).unwrap();
        Base64::<Standard, [u8; 16]>::parse(INPUT_WITH_PADDING).unwrap();

        // A length mismatch reports both the decoded and the expected length.
        let err = Base64::<Standard, [u8; 32]>::parse(INPUT_WITH_PADDING).unwrap_err();
        assert_eq!(err.to_string(), "invalid decoded base64 bytes length: 16, expected 32");
    }

    #[test]
    fn encode_is_unpadded() {
        let b64 = Base64::<Standard>::new(BYTES.to_vec());
        assert_eq!(b64.encode(), "+/8");
        assert_eq!(b64.to_string(), "+/8");
        assert_eq!(format!("{b64:?}"), "\"+/8\"");
    }

    #[test]
    fn decode_is_lenient() {
        // Padding is optional and non-zero trailing bits are ignored.
        for input in ["+/8", "+/8=", "+/9"] {
            assert_eq!(Base64::<Standard>::parse(input).unwrap().into_inner(), BYTES);
        }
    }

    #[test]
    fn url_safe_alphabet() {
        let b64 = Base64::<UrlSafe>::new(BYTES.to_vec());
        assert_eq!(b64.encode(), "-_8");
        assert_eq!(Base64::<UrlSafe>::parse("-_8").unwrap().into_inner(), BYTES);

        // Each alphabet rejects the other's symbols for 62 and 63.
        Base64::<Standard>::parse("-_8").unwrap_err();
        Base64::<UrlSafe>::parse("+/8").unwrap_err();
    }

    #[test]
    fn empty_and_roundtrip() {
        assert_eq!(Base64::<Standard>::empty().encode(), "");
        assert!(Base64::<Standard>::parse("").unwrap().as_bytes().is_empty());

        let bytes: Vec<u8> = (0..=255).collect();
        let encoded = Base64::<Standard>::new(bytes.clone()).encode();
        assert_eq!(Base64::<Standard>::parse(&encoded).unwrap().into_inner(), bytes);
    }
}