/// Standard base64 alphabet: entry `i` is the character encoding the 6-bit value `i`.
const ENCODING_TABLE: &[u8; 1 << 6] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Inverse of [`ENCODING_TABLE`]: maps a byte to its 6-bit value, or to
/// `u8::MAX` (0xFF) when the byte is not part of the alphabet — `=` included.
const DECODING_TABLE: [u8; 1 << 8] = {
    let mut lookup = [0xFF_u8; 256];
    // `for` is not usable in const context, so walk the alphabet back to front
    // with a manual counter; every character is distinct, so order is irrelevant.
    let mut value: usize = ENCODING_TABLE.len();
    while value > 0 {
        value -= 1;
        lookup[ENCODING_TABLE[value] as usize] = value as u8;
    }
    lookup
};
/// Reasons `Base64Binary::from_encoded` rejects its input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Base64Error {
    /// `byte` is not acceptable at `position`: either it lies outside the
    /// base64 alphabet, or it is the first `=` of padding that is followed by
    /// data or is longer than two characters. `position` indexes the raw input,
    /// skipped whitespace included.
    MalformedByte { byte: u8, position: usize },
    /// The retained (non-whitespace) input length is not a multiple of 4.
    InsufficientPadding,
}
impl std::fmt::Display for Base64Error {
    /// Human-readable description; the derived `Debug` keeps the structured fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Base64Error::MalformedByte { byte, position } => {
                // `escape_ascii` keeps control and non-ASCII bytes readable (`\n`, `\xff`).
                let shown = byte.escape_ascii();
                write!(f, "malformed base64 byte '{shown}' at position {position}")
            }
            Base64Error::InsufficientPadding => {
                f.write_str("base64 input length is not a multiple of 4 (insufficient padding)")
            }
        }
    }
}
impl std::error::Error for Base64Error {}
/// Base64 text (standard alphabet, `=` padding) kept in its encoded form.
///
/// The constructors in this file (`from_encoded`, and the `FromIterator` impl
/// used by `encode` and the `From` impls) only store base64-alphabet bytes and
/// `=`, with a length that is a multiple of 4. `Display` (which skips UTF-8
/// validation) and `decode` (which asserts the length) rely on that.
///
/// Equality and hashing are derived from the encoded text, not the decoded
/// bytes. NOTE(review): `from_encoded` does not require the unused bits before
/// `=` padding to be zero, so two values that decode to the same bytes can
/// still compare unequal — confirm this is intended.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Base64Binary {
    // The encoded base64 characters, not the raw payload.
    bin: Vec<u8>,
}
impl Base64Binary {
    /// Base64-encodes the bytes yielded by `iter` (standard alphabet, `=` padding).
    ///
    /// Shorthand for collecting the bytes into a `Base64Binary`.
    pub fn encode(iter: impl IntoIterator<Item = u8>) -> Self {
        iter.into_iter().collect()
    }

    /// Lazily decodes the stored base64 text back into the original bytes.
    ///
    /// Every 4-character group yields three bytes, except a group ending in
    /// `==` (one byte) or in a single `=` (two bytes).
    ///
    /// # Panics
    ///
    /// Panics if the stored text length is not a multiple of 4. The
    /// constructors in this file never produce such a value.
    pub fn decode(&self) -> impl Iterator<Item = u8> + '_ {
        assert!(
            self.bin.len() % 4 == 0,
            "Base64Binary invariant violated: encoded length {} is not a multiple of 4",
            self.bin.len()
        );
        self.bin.chunks_exact(4).flat_map(|chunk| {
            let [s0, s1, s2, s3] =
                [chunk[0], chunk[1], chunk[2], chunk[3]].map(|c| DECODING_TABLE[c as usize]);
            // `=` maps to u8::MAX in DECODING_TABLE and `from_encoded` only admits
            // it in the last one or two slots, so the first padded slot fixes how
            // many bytes this group carries.
            let count = if s2 == u8::MAX {
                1
            } else if s3 == u8::MAX {
                2
            } else {
                3
            };
            // Shifting a u8 left discards the high bits (not an overflow in Rust),
            // which is exactly the regrouping of 4x6 bits into 3x8 bits.
            let group = [(s0 << 2) | (s1 >> 4), (s1 << 4) | (s2 >> 2), (s2 << 6) | s3];
            IntoIterator::into_iter(group).take(count)
        })
    }

    /// Validates already-encoded base64 text and stores it without decoding.
    ///
    /// Accepted input consists of standard-alphabet characters followed by at
    /// most two trailing `=`, with a total length that is a multiple of 4.
    /// When `allow_whitespace` is true, ASCII whitespace anywhere in the input
    /// is skipped and not stored; otherwise it is rejected like any foreign byte.
    ///
    /// NOTE(review): the unused low bits of the last data character before `=`
    /// are not checked, so e.g. both `QR==` and `QQ==` are accepted and decode
    /// to the same byte.
    ///
    /// # Errors
    ///
    /// * [`Base64Error::MalformedByte`] for the first byte outside the alphabet
    ///   (reported as soon as it is seen), or for the first `=` when padding is
    ///   longer than two characters or is followed by data. `position` indexes
    ///   the raw input, skipped whitespace included.
    /// * [`Base64Error::InsufficientPadding`] when the retained length is not a
    ///   multiple of 4 (checked before the padding layout).
    pub fn from_encoded(
        iter: impl IntoIterator<Item = u8>,
        allow_whitespace: bool,
    ) -> Result<Self, Base64Error> {
        let iter = iter.into_iter();
        let mut bin = Vec::with_capacity(iter.size_hint().0);
        // (index in `bin`, index in the raw input) of the first `=` seen.
        let mut first_pad: Option<(usize, usize)> = None;
        for (position, byte) in iter.enumerate() {
            if allow_whitespace && byte.is_ascii_whitespace() {
                continue;
            }
            if byte == b'=' {
                first_pad.get_or_insert((bin.len(), position));
            } else if DECODING_TABLE[byte as usize] == u8::MAX {
                return Err(Base64Error::MalformedByte { byte, position });
            }
            bin.push(byte);
        }
        if bin.len() % 4 != 0 {
            return Err(Base64Error::InsufficientPadding);
        }
        // Padding must be the final one or two characters and contain only `=`.
        if let Some((pad_start, position)) = first_pad {
            let padding = &bin[pad_start..];
            if padding.len() > 2 || padding.iter().any(|&b| b != b'=') {
                return Err(Base64Error::MalformedByte {
                    byte: b'=',
                    position,
                });
            }
        }
        Ok(Base64Binary { bin })
    }
}
impl FromIterator<u8> for Base64Binary {
    /// Base64-encodes every byte yielded by `iter` using the standard alphabet.
    ///
    /// Each group of three input bytes becomes four output characters. A final
    /// group of two bytes is completed with one `=`, a final single byte with `==`.
    fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
        let mut iter = iter.into_iter();
        // Pre-size from the iterator's lower bound (exact for slices, `Vec`s and
        // `str::bytes`) so long inputs are encoded without repeatedly regrowing `bin`.
        let mut bin = Vec::with_capacity(iter.size_hint().0.div_ceil(3).saturating_mul(4));
        while let Some(b0) = iter.next() {
            // High six bits of the first byte.
            bin.push(ENCODING_TABLE[(b0 >> 2) as usize]);
            match iter.next() {
                Some(b1) => {
                    // Low two bits of b0 joined with the high four bits of b1.
                    bin.push(ENCODING_TABLE[(((b0 & 0x3) << 4) | (b1 >> 4)) as usize]);
                    match iter.next() {
                        Some(b2) => {
                            bin.push(ENCODING_TABLE[(((b1 & 0xF) << 2) | (b2 >> 6)) as usize]);
                            bin.push(ENCODING_TABLE[(b2 & 0x3F) as usize]);
                        }
                        None => {
                            // Two-byte tail: last sextet is zero-filled, then one pad.
                            bin.push(ENCODING_TABLE[((b1 & 0xF) << 2) as usize]);
                            bin.push(b'=');
                        }
                    }
                }
                None => {
                    // One-byte tail: second sextet is zero-filled, then two pads.
                    bin.push(ENCODING_TABLE[((b0 & 0x3) << 4) as usize]);
                    bin.push(b'=');
                    bin.push(b'=');
                }
            }
        }
        Base64Binary { bin }
    }
}
impl From<&str> for Base64Binary {
    /// Base64-encodes the UTF-8 bytes of `value`.
    fn from(value: &str) -> Self {
        Self::encode(value.as_bytes().iter().copied())
    }
}
/// Implements `From<T>` and `From<&T>` for each listed string-like type `T`,
/// base64-encoding the UTF-8 bytes the value holds.
macro_rules! impl_from_str_for_base64_binary {
    ( $( $ty:ty ),* ) => {
        $(
            impl From<&$ty> for Base64Binary {
                fn from(value: &$ty) -> Self {
                    Self::encode(value.as_bytes().iter().copied())
                }
            }
            impl From<$ty> for Base64Binary {
                fn from(value: $ty) -> Self {
                    Self::encode(value.as_bytes().iter().copied())
                }
            }
        )*
    };
}
// Owned and shared string containers; plain `&str` has its own impl.
impl_from_str_for_base64_binary!(
    String,
    Box<str>,
    std::rc::Rc<str>,
    std::sync::Arc<str>,
    std::borrow::Cow<'_, str>
);
impl From<&[u8]> for Base64Binary {
    /// Base64-encodes the raw bytes of `value`.
    fn from(value: &[u8]) -> Self {
        Self::encode(value.iter().copied())
    }
}
/// Implements `From<T>` and `From<&T>` for each listed byte container `T`
/// (anything that derefs to `[u8]`), base64-encoding its contents.
macro_rules! impl_from_bytes_for_base64_binary {
    ( $( $ty:ty ),* ) => {
        $(
            impl From<&$ty> for Base64Binary {
                fn from(value: &$ty) -> Self {
                    // Deref coercion turns the container reference into a plain slice.
                    let raw: &[u8] = value;
                    Self::encode(raw.iter().copied())
                }
            }
            impl From<$ty> for Base64Binary {
                fn from(value: $ty) -> Self {
                    let raw: &[u8] = &value;
                    Self::encode(raw.iter().copied())
                }
            }
        )*
    };
}
// Owned and shared byte containers; plain `&[u8]` has its own impl.
impl_from_bytes_for_base64_binary!(
    Vec<u8>,
    Box<[u8]>,
    std::rc::Rc<[u8]>,
    std::sync::Arc<[u8]>,
    std::borrow::Cow<'_, [u8]>
);
impl std::fmt::Debug for Base64Binary {
    /// Shows the encoded text exactly as `Display` does (no quotes, no field names).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}
impl std::fmt::Display for Base64Binary {
    /// Writes the stored base64 text verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // SAFETY: the constructors in this file only store base64-alphabet bytes
        // and `=` (all ASCII), so `bin` is always valid UTF-8.
        let text = unsafe { std::str::from_utf8_unchecked(&self.bin) };
        f.write_str(text)
    }
}
#[cfg(test)]
mod tests {
    use std::hash::{BuildHasher, Hasher, RandomState};
    use super::*;

    /// Endless xorshift32 stream derived from `seed`.
    fn xor_shift32(seed: u64) -> impl Iterator<Item = u32> {
        // Fold both halves of the seed together and force the state odd:
        // xorshift maps a zero state to zero forever, which would silently turn
        // every generated case into an empty input.
        let mut random = ((seed ^ (seed >> 32)) as u32) | 1;
        std::iter::repeat_with(move || {
            random ^= random << 13;
            random ^= random >> 17;
            random ^= random << 5;
            random
        })
    }

    /// Endless pseudo-random byte stream: the little-endian bytes of each
    /// xorshift32 output, in order (no leading zero bytes).
    fn bytes(seed: u64) -> impl Iterator<Item = u8> {
        xor_shift32(seed).flat_map(u32::to_le_bytes)
    }

    /// Fresh per-run seed. Tests include it in failure messages so a failing
    /// run can be replayed by hard-coding the printed value.
    fn random_seed() -> u64 {
        RandomState::new().build_hasher().finish()
    }

    #[test]
    fn regression_tests() {
        let seed = random_seed();
        let mut rng = bytes(seed);
        for _ in 0..10000 {
            let len = usize::from(rng.next().unwrap());
            let data = rng.by_ref().take(len).collect::<Vec<_>>();
            let encoded = data.iter().copied().collect::<Base64Binary>();
            assert_eq!(encoded.bin.len(), data.len().div_ceil(3) * 4, "seed: {seed}");
            let decoded = encoded.decode().collect::<Vec<_>>();
            assert_eq!(data, decoded, "seed: {seed}, len: {},{}", data.len(), decoded.len());
            // All padding must be trailing, and its amount is fixed by the input length.
            let pad = encoded.bin.iter().filter(|&&c| c == b'=').count();
            let trailing = match encoded.bin.as_slice() {
                [.., b'=', b'='] => 2,
                [.., b'='] => 1,
                _ => 0,
            };
            assert_eq!(pad, trailing, "seed: {seed}");
            assert_eq!(pad, (3 - data.len() % 3) % 3, "seed: {seed}");
            let reparsed = Base64Binary::from_encoded(encoded.to_string().bytes(), false)
                .unwrap_or_else(|err| panic!("seed: {seed}, error: {err:?}"));
            assert_eq!(encoded, reparsed, "seed: {seed}");
        }
    }

    #[test]
    fn rfc4648_test_vectors() {
        // RFC 4648 §10 test vectors; a symmetric encode/decode bug would still
        // round-trip, so pin the exact text.
        let cases = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for (plain, expected) in cases {
            let encoded = Base64Binary::from(plain);
            assert_eq!(encoded.to_string(), expected);
            assert_eq!(encoded.decode().collect::<Vec<_>>(), plain.as_bytes());
            assert_eq!(Base64Binary::from_encoded(expected.bytes(), false), Ok(encoded));
        }
    }

    #[test]
    fn encoded_bytes_tests() {
        assert!(Base64Binary::from_encoded(*b"", false).is_ok());
        let seed = random_seed();
        let mut rng = bytes(seed);
        for _ in 0..10000 {
            let len = usize::from(rng.next().unwrap()).div_ceil(4) * 4;
            let chars = rng
                .by_ref()
                .filter(|b| DECODING_TABLE[*b as usize] != u8::MAX)
                .take(len)
                .collect::<Vec<_>>();
            let strict = Base64Binary::from_encoded(chars.iter().copied(), false);
            assert!(strict.is_ok(), "seed: {seed}, input: {chars:?}");
            // The same characters interleaved with spaces must yield the same
            // value when whitespace is allowed.
            let spaced = chars.iter().flat_map(|&c| [c, b' ']);
            assert_eq!(Base64Binary::from_encoded(spaced, true), strict, "seed: {seed}");
        }
    }

    #[test]
    fn whitespace_tests() {
        let foobar = Base64Binary::from("foobar");
        assert_eq!(Base64Binary::from_encoded(*b" Zm9v\nYmFy\r\n", true), Ok(foobar));
        assert_eq!(
            Base64Binary::from_encoded(*b"Zm9v YmFy", false),
            Err(Base64Error::MalformedByte { byte: b' ', position: 4 })
        );
    }

    #[test]
    fn erroneous_encoded_bytes_tests() {
        assert!(Base64Binary::from_encoded(*b"a", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaaaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaaaaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaaaaaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"=", false).is_err());
        assert!(Base64Binary::from_encoded(*b"==", false).is_err());
        assert!(Base64Binary::from_encoded(*b"===", false).is_err());
        assert!(Base64Binary::from_encoded(*b"====", false).is_err());
        assert!(Base64Binary::from_encoded(*b"a=", false).is_err());
        assert!(Base64Binary::from_encoded(*b"a==", false).is_err());
        assert!(Base64Binary::from_encoded(*b"a===", false).is_err());
        // Exact variants and positions (positions index the raw input).
        assert_eq!(
            Base64Binary::from_encoded(*b"ab!d", false),
            Err(Base64Error::MalformedByte { byte: b'!', position: 2 })
        );
        assert_eq!(
            Base64Binary::from_encoded(*b"aaaaa", false),
            Err(Base64Error::InsufficientPadding)
        );
        assert_eq!(
            Base64Binary::from_encoded(*b"a===", false),
            Err(Base64Error::MalformedByte { byte: b'=', position: 1 })
        );
        assert_eq!(
            Base64Binary::from_encoded(*b"aa==aaaa", false),
            Err(Base64Error::MalformedByte { byte: b'=', position: 2 })
        );
        assert_eq!(
            Base64Binary::from_encoded(*b" a=b=", true),
            Err(Base64Error::MalformedByte { byte: b'=', position: 2 })
        );
    }
}