use crate::prelude::*;
/// Decodes standard (RFC 4648 §4 alphabet) padded base64 into raw bytes.
///
/// ASCII space, tab, CR and LF are skipped wherever they appear, so wrapped
/// text is accepted. After skipping them, the input must consist of whole
/// 4-symbol quads. `=` padding may appear only at the end of the final quad,
/// as `xx==` (1 output byte) or `xxx=` (2 output bytes). Empty (or
/// whitespace-only) input decodes to an empty vector.
///
/// The unused low bits of a padded final quad are not required to be zero,
/// so non-canonical encodings such as `QR==` are accepted.
///
/// # Errors
///
/// Returns a static message when, after whitespace removal:
/// - the length is not a multiple of 4;
/// - any byte lies outside the base64 alphabet (reported as an alphabet error
///   regardless of its position within the quad);
/// - `=` appears anywhere other than a valid trailing padding position.
pub(crate) fn decode(input: &str) -> Result<Vec<u8>, &'static str> {
    // Byte -> 6-bit value table, evaluated at compile time (statics require const init).
    static LUT: [i8; 256] = build_lut();
    // Sentinels written into `LUT` by `build_lut`: non-alphabet bytes and '='.
    const INVALID: i8 = -1;
    const PAD: i8 = -2;

    let mut buf: Vec<u8> = Vec::with_capacity(input.len());
    buf.extend(
        input
            .bytes()
            .filter(|&b| !matches!(b, b' ' | b'\t' | b'\r' | b'\n')),
    );
    if buf.len() % 4 != 0 {
        return Err("base64 length not a multiple of 4 (after stripping whitespace)");
    }

    let quad_count = buf.len() / 4;
    let mut out: Vec<u8> = Vec::with_capacity(quad_count * 3);
    for (index, quad) in buf.chunks_exact(4).enumerate() {
        // Padding is only legal in the final quad.
        let is_last = index + 1 == quad_count;
        let sextets = [quad[0], quad[1], quad[2], quad[3]].map(|c| LUT[usize::from(c)]);
        // Check for foreign bytes first so they are never misreported as padding errors.
        if sextets.contains(&INVALID) {
            return Err("invalid base64 character (alphabet)");
        }
        let [q0, q1, q2, q3] = sextets;
        // Only '=' can remain as a negative value here; it is never valid in the first two slots.
        if q0 == PAD || q1 == PAD {
            return Err("invalid base64 padding");
        }
        // q0/q1 are now in 0..=63, so the `as u32` casts are lossless.
        let head = ((q0 as u32) << 18) | ((q1 as u32) << 12);
        match (q2, q3) {
            (0.., 0..) => {
                let v = head | ((q2 as u32) << 6) | (q3 as u32);
                // Truncating `as u8` casts keep the low 8 bits of each shifted value.
                out.extend_from_slice(&[(v >> 16) as u8, (v >> 8) as u8, v as u8]);
            }
            (PAD, PAD) if is_last => out.push((head >> 16) as u8),
            (0.., PAD) if is_last => {
                let v = head | ((q2 as u32) << 6);
                out.extend_from_slice(&[(v >> 16) as u8, (v >> 8) as u8]);
            }
            _ => return Err("invalid base64 padding"),
        }
    }
    Ok(out)
}
/// Encodes `input` as standard base64 (RFC 4648 §4 alphabet) with `=` padding.
///
/// Every 3 input bytes become 4 output symbols. A trailing 1- or 2-byte group
/// is padded to a full 4-symbol quad with `==` or `=` respectively, so the
/// result is always `4 * ceil(input.len() / 3)` characters long. Empty input
/// yields an empty string.
pub(crate) fn encode(input: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Symbol for the 6-bit field of `word` whose lowest bit sits at position `shift`.
    let symbol =
        |word: u32, shift: u32| -> char { ALPHABET[((word >> shift) & 0x3f) as usize] as char };

    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    let groups = input.chunks_exact(3);
    let tail = groups.remainder();
    for group in groups {
        let word =
            (u32::from(group[0]) << 16) | (u32::from(group[1]) << 8) | u32::from(group[2]);
        for shift in [18, 12, 6, 0] {
            encoded.push(symbol(word, shift));
        }
    }
    match *tail {
        [] => {}
        [a] => {
            let word = u32::from(a) << 16;
            encoded.push(symbol(word, 18));
            encoded.push(symbol(word, 12));
            encoded.push_str("==");
        }
        [a, b] => {
            let word = (u32::from(a) << 16) | (u32::from(b) << 8);
            for shift in [18, 12, 6] {
                encoded.push(symbol(word, shift));
            }
            encoded.push('=');
        }
        _ => unreachable!("chunks_exact(3) leaves 0..=2 bytes"),
    }
    encoded
}
/// Builds the 256-entry reverse lookup table used for base64 decoding.
///
/// Each byte of the standard alphabet (`A-Z`, `a-z`, `0-9`, `+`, `/`) maps to
/// its 6-bit value in `0..=63`. The padding byte `=` maps to `-2`, and every
/// other byte maps to `-1`. As a `const fn`, it can initialise a `static` at
/// compile time.
const fn build_lut() -> [i8; 256] {
    // A symbol's index in this string is its 6-bit value.
    const SYMBOLS: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut table = [-1_i8; 256];
    // `for` is not allowed in `const fn`, hence the manual counter.
    let mut value = 0;
    while value < SYMBOLS.len() {
        table[SYMBOLS[value] as usize] = value as i8;
        value += 1;
    }
    table[b'=' as usize] = -2;
    table
}
/// Unit tests for the base64 `encode`/`decode` pair: known vectors,
/// round-trips across every remainder length, and malformed-input rejection.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip_empty() {
        assert!(decode("").unwrap().is_empty());
        assert!(encode(&[]).is_empty());
    }

    #[test]
    fn rfc4648_test_vectors() {
        // Test vectors from RFC 4648 §10.
        let vectors = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for (raw, encoded) in vectors {
            assert_eq!(encode(raw.as_bytes()), encoded);
            assert_eq!(decode(encoded).unwrap(), raw.as_bytes());
        }
    }

    #[test]
    fn roundtrip_one_byte() {
        for b in 0u8..=255 {
            let s = encode(&[b]);
            assert_eq!(s.len(), 4);
            assert_eq!(decode(&s).unwrap(), vec![b]);
        }
    }

    #[test]
    fn roundtrip_two_bytes() {
        let bytes = [0xab, 0xcd];
        let s = encode(&bytes);
        assert_eq!(s, "q80=");
        assert_eq!(decode(&s).unwrap(), bytes);
    }

    #[test]
    fn roundtrip_three_bytes() {
        let bytes = [0x01, 0x02, 0x03];
        let s = encode(&bytes);
        assert_eq!(s, "AQID");
        assert_eq!(decode(&s).unwrap(), bytes);
    }

    #[test]
    fn roundtrip_every_length_up_to_sixteen() {
        // Covers every remainder (0, 1, 2 trailing bytes) several times over.
        let source: Vec<u8> = (0..16u8)
            .map(|i| i.wrapping_mul(37).wrapping_add(11))
            .collect();
        for len in 0..=source.len() {
            let bytes = &source[..len];
            let s = encode(bytes);
            assert_eq!(s.len(), len.div_ceil(3) * 4);
            assert_eq!(decode(&s).unwrap(), bytes);
        }
    }

    #[test]
    fn roundtrip_hello() {
        let s = encode(b"Hello, World!");
        assert_eq!(s, "SGVsbG8sIFdvcmxkIQ==");
        assert_eq!(decode(&s).unwrap(), b"Hello, World!");
    }

    #[test]
    fn decode_tolerates_whitespace_and_newlines() {
        let s = "SGVs\n bG8s\n IFdv\n cmxk\n IQ==\n";
        assert_eq!(decode(s).unwrap(), b"Hello, World!");
    }

    #[test]
    fn decode_whitespace_only_is_empty() {
        assert!(decode(" \t\r\n").unwrap().is_empty());
    }

    #[test]
    fn decode_rejects_invalid_alphabet() {
        assert!(decode("**==").is_err());
    }

    #[test]
    fn decode_rejects_invalid_alphabet_in_last_two_positions() {
        assert!(decode("AB*C").is_err());
        assert!(decode("ABC*").is_err());
    }

    #[test]
    fn decode_rejects_bad_padding_position() {
        assert!(decode("AB==CDEF").is_err());
    }

    #[test]
    fn decode_rejects_padding_in_first_two_positions() {
        assert!(decode("=AAA").is_err());
        assert!(decode("A===").is_err());
    }

    #[test]
    fn decode_rejects_data_after_padding_within_quad() {
        assert!(decode("AB=C").is_err());
    }

    #[test]
    fn decode_rejects_bad_length() {
        assert!(decode("ABC").is_err());
        assert!(decode("A").is_err());
        assert!(decode("ABCDE").is_err());
    }

    #[test]
    fn roundtrip_cyclic_byte_pattern() {
        // 1023 bytes = 341 full groups, so every byte value appears in every
        // position of a 3-byte group at least once.
        let bytes: Vec<u8> = (0..=255_u8).cycle().take(1023).collect();
        let s = encode(&bytes);
        assert_eq!(decode(&s).unwrap(), bytes);
    }
}