use std::marker::PhantomData;
use bitreader::BitReader;
use crate::alphabet::Alphabet;
#[derive(Debug, Clone)]
pub struct Base64String<A> {
content: String,
_alphabet: PhantomData<A>,
}
impl<A> Base64String<A>
where
A: Alphabet,
{
pub fn encode(bytes: &[u8]) -> Self {
let mut chunks = bytes.chunks_exact(3);
let mut encoded = vec![];
#[allow(clippy::while_let_on_iterator)] while let Some(chunk) = chunks.next() {
encoded.push(Self::encode_triplet([chunk[0], chunk[1], chunk[2]]))
}
let rem = chunks.remainder();
match rem.len() {
0 => { }
1 => encoded.push(Self::encode_singlet([rem[0]])),
2 => encoded.push(Self::encode_doublet([rem[0], rem[1]])),
_ => unreachable!("{}", rem.len()), }
Self {
content: encoded.iter().flatten().collect(),
_alphabet: PhantomData,
}
}
pub fn decode(&self) -> Vec<u8> {
let mut decoded = vec![];
let tmp = self.content.chars().collect::<Vec<_>>();
let segments = tmp.chunks_exact(4);
for seg in segments {
if seg.ends_with(&[A::PADDING, A::PADDING]) {
let tri = Self::decode_quad([seg[0], seg[1], 0 as char, 0 as char]);
decoded.push(tri[0]);
} else if seg.ends_with(&[A::PADDING]) {
let tri = Self::decode_quad([seg[0], seg[1], seg[2], 0 as char]);
decoded.push(tri[0]);
decoded.push(tri[1]);
} else {
let tri = Self::decode_quad([seg[0], seg[1], seg[2], seg[3]]);
for byte in tri {
decoded.push(byte)
}
}
}
decoded
}
pub fn from_encoded(b64: &str) -> Self {
let mut content = b64.to_string();
while content.len() % 4 != 0 {
content.push(A::PADDING)
}
Self {
content,
_alphabet: PhantomData::<A>,
}
}
pub fn without_padding(&self) -> String {
self.content.chars().filter(|&c| c != A::PADDING).collect()
}
pub fn change_alphabet<B>(self) -> Base64String<B>
where
B: Alphabet,
{
let inner = self.decode();
Base64String::<B>::encode(&inner)
}
fn decode_quad([a, b, c, d]: [char; 4]) -> [u8; 3] {
let concat_bytes = ((A::decode_char(a) as u32) << 18)
| ((A::decode_char(b) as u32) << 12)
| ((A::decode_char(c) as u32) << 6)
| A::decode_char(d) as u32;
[
((concat_bytes >> 16) & 0b1111_1111) as u8,
((concat_bytes >> 8) & 0b1111_1111) as u8,
(concat_bytes & 0b1111_1111) as u8,
]
}
fn encode_triplet(triple: [u8; 3]) -> [char; 4] {
let mut reader = BitReader::new(&triple);
let first = reader.read_u8(6).unwrap();
let second = reader.read_u8(6).unwrap();
let third = reader.read_u8(6).unwrap();
let fourth = reader.read_u8(6).unwrap();
[
A::encode_bits(first),
A::encode_bits(second),
A::encode_bits(third),
A::encode_bits(fourth),
]
}
fn encode_singlet(rem: [u8; 1]) -> [char; 4] {
let mut reader = BitReader::new(&rem);
let six = reader.read_u8(6).unwrap();
let half_nib = reader.read_u8(2).unwrap();
let (half_nib, _) = half_nib.overflowing_shl(4);
let first = A::encode_bits(six);
let second = A::encode_bits(half_nib);
[first, second, A::PADDING, A::PADDING]
}
fn encode_doublet(rem: [u8; 2]) -> [char; 4] {
let mut reader = BitReader::new(&rem);
let six1 = reader.read_u8(6).unwrap();
let six2 = reader.read_u8(6).unwrap();
let nibble = reader.read_u8(4).unwrap();
let (nibble, _) = nibble.overflowing_shl(2);
let first = A::encode_bits(six1);
let second = A::encode_bits(six2);
let third = A::encode_bits(nibble);
[first, second, third, A::PADDING]
}
}
impl<A> core::fmt::Display for Base64String<A>
where
A: Alphabet,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.content)
}
}
impl<A> PartialEq for Base64String<A>
where
A: Alphabet,
{
fn eq(&self, other: &Self) -> bool {
self.content == other.content
}
}
#[cfg(test)]
mod tests {
use crate::alphabet::Standard;
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn encode_triplet() {
let triplet = ['A', 'B', 'C'];
let expected_encoded = ['Q', 'U', 'J', 'D'];
let encoded = Base64String::<Standard>::encode_triplet([
triplet[0] as u8,
triplet[1] as u8,
triplet[2] as u8,
]);
assert_eq!(encoded, expected_encoded);
}
#[test]
fn encode_long() {
let input = "everybody".chars().map(|c| c as u8);
let b64 = Base64String::encode(&input.collect::<Vec<_>>());
let expected = Base64String {
content: String::from("ZXZlcnlib2R5"),
_alphabet: PhantomData::<Standard>,
};
assert_eq!(b64, expected)
}
#[test]
fn encode_2_rem() {
let input = "event".chars().map(|c| c as u8);
let b64 = Base64String::encode(&input.collect::<Vec<_>>());
let expected = Base64String {
content: String::from("ZXZlbnQ="),
_alphabet: PhantomData::<Standard>,
};
assert_eq!(b64, expected)
}
#[test]
fn encode_1_rem() {
let input = "even".chars().map(|c| c as u8);
let b64 = Base64String::encode(&input.collect::<Vec<_>>());
let expected = Base64String {
content: String::from("ZXZlbg=="),
_alphabet: PhantomData::<Standard>,
};
assert_eq!(b64, expected)
}
#[test]
fn decode_no_pad() {
let src = Base64String {
content: String::from("ZXZlcnlib2R5"),
_alphabet: PhantomData::<Standard>,
};
let expected = b"everybody".to_vec();
let decoded = src.decode();
assert_eq!(decoded, expected)
}
#[test]
fn decode_one_pad() {
let src = Base64String {
content: String::from("ZXZlbnQ="),
_alphabet: PhantomData::<Standard>,
};
let expected = b"event".to_vec();
let decoded = src.decode();
assert_eq!(decoded, expected)
}
#[test]
fn decode_two_pad() {
let src = Base64String {
content: String::from("ZXZlbg=="),
_alphabet: PhantomData::<Standard>,
};
let expected = b"even".to_vec();
let decoded = src.decode();
assert_eq!(decoded, expected)
}
}