use crate::{
chunked::{ChunkedDecoder, ChunkedEncoder, ChunkedEncoderImpl},
padding::{Equals, Unpadded}
};
/// Encoder backing type: the generic `ChunkedEncoder` specialised with `Base64Impl`
/// and the const arguments `64, 3, 4`.
///
/// The `3` and `4` match the `ChunkedEncoderImpl<3, 4>` impl on `Base64Impl`
/// (3 raw bytes per 4 symbols). `64` is presumably the alphabet size; confirm against
/// the `ChunkedEncoder` definition. `A`, `D` and `P` are forwarded unchanged.
type Base64EncoderImpl<A, D, P> = ChunkedEncoder<A, D, P, Base64Impl, 64, 3, 4>;
/// Decoder backing type: the generic `ChunkedDecoder` specialised with `Base64Impl`
/// and the same `64, 3, 4` const arguments as the encoder alias. `D` and `P` are
/// forwarded unchanged.
type Base64DecoderImpl<D, P> = ChunkedDecoder<D, P, Base64Impl, 64, 3, 4>;
// Invokes `impl_encoding!` (defined elsewhere in the crate) with the names
// `Base64Encoder` / `Base64Decoder`, the constant 64 and the two backing aliases.
// NOTE(review): the exact items generated depend on the macro definition, which is
// not visible from this file.
impl_encoding!(Base64Encoder, Base64Decoder, 64, Base64EncoderImpl, Base64DecoderImpl);
// `Base64`: 64-symbol alphabet `A-Z a-z 0-9 + /`, combined with the `Equals` padding
// policy (presumably `=` padding, as exercised by the tests in this file).
impl_encoder!(
Base64,
64,
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
Equals
);
// `Base64Unpadded`: same `A-Z a-z 0-9 + /` alphabet as `Base64`, but combined with the
// `Unpadded` policy (the tests in this file expect its output to carry no trailing `=`).
impl_encoder!(
Base64Unpadded,
64,
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
Unpadded
);
// `Base64UrlPadded`: alphabet `A-Z a-z 0-9 - _` (the last two symbols are `-` and `_`
// instead of `+` and `/`), combined with the `Equals` padding policy.
impl_encoder!(
Base64UrlPadded,
64,
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_",
Equals
);
// `Base64UrlUnpadded`: alphabet `A-Z a-z 0-9 - _` (the last two symbols are `-` and `_`
// instead of `+` and `/`), combined with the `Unpadded` policy.
impl_encoder!(
Base64UrlUnpadded,
64,
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_",
Unpadded
);
/// Field-less marker type that carries the Base64 chunk logic.
///
/// It holds no data; all behaviour lives in its `ChunkedEncoderImpl<3, 4>` impl, and it
/// is plugged into the `ChunkedEncoder` / `ChunkedDecoder` aliases as a type parameter.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct Base64Impl;
/// Base64 chunk arithmetic: every 3 raw bytes correspond to 4 six-bit symbol values.
impl ChunkedEncoderImpl<3, 4> for Base64Impl {
    /// Returns how many 6-bit symbols are needed to hold `len` raw bytes of a single
    /// chunk (`0 -> 0`, `1 -> 2`, `2 -> 3`, `3 -> 4`).
    ///
    /// `len` must not exceed 3; this is only checked in debug builds.
    #[inline]
    fn padding_len(len: usize) -> usize {
        debug_assert!(len <= 3);
        let bits = len * 8;
        bits.div_ceil(6)
    }

    /// Returns the number of symbols produced for `len` input bytes, or `None` if the
    /// result does not fit in a `usize`.
    ///
    /// With `padded` set, the length is rounded up to whole 4-symbol chunks; otherwise
    /// a trailing partial chunk only contributes the symbols it actually needs.
    #[inline]
    fn encoded_len(len: usize, padded: bool) -> Option<usize> {
        if padded {
            return len.div_ceil(3).checked_mul(4);
        }
        let whole_chunks = len / 3;
        let tail_symbols = Self::padding_len(len % 3);
        whole_chunks
            .checked_mul(4)
            .and_then(|symbols| symbols.checked_add(tail_symbols))
    }

    /// Returns the number of bytes that `src` decodes to, judging by its length (and
    /// trailing padding bytes) only, or `None` if that length is impossible.
    ///
    /// * `padding = Some(p)`: `src.len()` must be a multiple of 4 and at most two of
    ///   the trailing bytes may equal `p`.
    /// * `padding = None`: a remainder of exactly one symbol (`len % 4 == 1`) is
    ///   rejected, since one symbol cannot hold a full byte.
    ///
    /// Empty input always yields `Some(0)`.
    #[inline]
    fn decoded_len(src: &[u8], padding: Option<u8>) -> Option<usize> {
        if src.is_empty() {
            return Some(0);
        }
        let (whole_chunks, tail_symbols) = match padding {
            Some(pad) => {
                if !src.len().is_multiple_of(4) {
                    return None;
                }
                // Count the run of padding bytes at the very end (looking at up to 4 bytes).
                let pad_count = src.iter().rev().take(4).take_while(|&&byte| byte == pad).count();
                if pad_count > 2 {
                    return None;
                }
                // The final chunk is accounted for separately through its data symbols.
                ((src.len() / 4).saturating_sub(1), 4 - pad_count)
            }
            None => {
                let tail = src.len() % 4;
                if tail == 1 {
                    return None;
                }
                (src.len() / 4, tail)
            }
        };
        // Each symbol carries 6 bits; only complete bytes of the tail count.
        Some(whole_chunks * 3 + tail_symbols * 6 / 8)
    }

    /// Splits three raw bytes into four 6-bit values (most significant group first).
    ///
    /// The outputs are symbol *values* in `0..64`, not alphabet characters.
    #[inline]
    fn encode_chunk_raw(src: &[u8; 3], dst: &mut [u8; 4]) {
        let [b0, b1, b2] = *src;
        let word = u32::from_be_bytes([0, b0, b1, b2]);
        for (i, out) in dst.iter_mut().enumerate() {
            let shift = 6 * (3 - i);
            *out = (word >> shift) as u8 & 0b11_1111;
        }
    }

    /// Packs four symbol values back into three raw bytes; always returns `Ok(())`.
    ///
    /// The inputs are not range-checked here; each byte is shifted in as-is, so
    /// callers are expected to pass values below 64.
    #[inline]
    fn decode_raw_chunk(src: &[u8; 4], dst: &mut [u8; 3]) -> Result<(), ()> {
        let mut word = 0u32;
        for &symbol in src {
            word = word << 6 | u32::from(symbol);
        }
        // The low 24 bits of `word` are the three decoded bytes in big-endian order.
        let [_, hi, mid, lo] = word.to_be_bytes();
        *dst = [hi, mid, lo];
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DecodeIntoError;

    /// Checks `encoded_len`, `encode_into` and `decode_into` for `src` against
    /// `expected` using `Base64`, then repeats the same three checks with
    /// `Base64Unpadded` against `expected` stripped of its trailing `=`.
    fn check_roundtrip(src: &[u8], expected: &str) {
        let mut buf = [0; 64];

        // Padded codec.
        assert_eq!(Base64.encoded_len(src.len()).unwrap(), expected.len());
        assert_eq!(Base64.encode_into(src, &mut buf).unwrap(), expected);
        assert_eq!(Base64.decode_into(expected.as_bytes(), &mut buf).unwrap(), src);

        // Unpadded codec: same text without the trailing `=`.
        let expected = expected.trim_end_matches('=');
        assert_eq!(Base64Unpadded.encoded_len(src.len()).unwrap(), expected.len());
        assert_eq!(Base64Unpadded.encode_into(src, &mut buf).unwrap(), expected);
        assert_eq!(Base64Unpadded.decode_into(expected.as_bytes(), &mut buf).unwrap(), src);
    }

    #[test]
    fn rfc4648() {
        const VECTORS: &[(&[u8], &str)] = &[
            (b"", ""),
            (b"f", "Zg=="),
            (b"fo", "Zm8="),
            (b"foo", "Zm9v"),
            (b"foob", "Zm9vYg=="),
            (b"fooba", "Zm9vYmE="),
            (b"foobar", "Zm9vYmFy"),
        ];
        for &(src, expected) in VECTORS {
            check_roundtrip(src, expected);
        }
    }

    #[test]
    fn wikipedia() {
        const VECTORS: &[(&[u8], &str)] = &[
            (b"Many hands make light work.", "TWFueSBoYW5kcyBtYWtlIGxpZ2h0IHdvcmsu"),
            (b"Man", "TWFu"),
            (b"Ma", "TWE="),
            (b"M", "TQ=="),
            (b"light work.", "bGlnaHQgd29yay4="),
            (b"light work", "bGlnaHQgd29yaw=="),
            (b"light wor", "bGlnaHQgd29y"),
            (b"light wo", "bGlnaHQgd28="),
            (b"light w", "bGlnaHQgdw=="),
        ];
        for &(src, expected) in VECTORS {
            check_roundtrip(src, expected);
        }
    }

    #[test]
    fn canonical() {
        /// Decodes `input` with `decoder` and compares the outcome with `expected`:
        /// `Ok(bytes)` is the decoded output, `Err(true)` means `NonCanonical` and
        /// `Err(false)` means `InvalidLength`. Any other outcome panics.
        #[track_caller]
        fn check_decode(decoder: &dyn DynDecoder, input: &str, expected: Result<&[u8], bool>) {
            let mut scratch = [0; 8];
            match decoder.decode_into(input.as_bytes(), &mut scratch) {
                Ok(decoded) => assert_eq!(decoded, expected.unwrap()),
                Err(DecodeIntoError::InvalidLength) => assert!(!expected.unwrap_err()),
                Err(DecodeIntoError::NonCanonical) => assert!(expected.unwrap_err()),
                r => panic!("{input}: {r:?}")
            }
        }

        // Canonical encodings round-trip.
        check_roundtrip(&[0x43, 0x33, 0x56, 0xc1], "QzNWwQ==");
        check_roundtrip(b"Hello", "SGVsbG8=");
        check_roundtrip(b"Hell", "SGVsbA==");

        // Same lengths, but the final symbol carries non-zero leftover bits.
        check_decode(&Base64, "QzNWwc==", Err(true));
        check_decode(&Base64Unpadded, "QzNWwc", Err(true));
        check_decode(&Base64, "SGVsbG9=", Err(true));
        check_decode(&Base64Unpadded, "SGVsbG9", Err(true));

        // Missing padding is a length error for the padded codec only.
        check_decode(&Base64, "SGVsbG9", Err(false));
        check_decode(&Base64Unpadded, "SGVsbG9", Err(true));

        // Wrong amounts of padding.
        check_decode(&Base64, "SGVsbA=", Err(false));
        check_decode(&Base64, "SGVsbA====", Err(false));
        check_decode(&Base64, "SGVsbA===", Err(false));
    }
}