use alloc::string::String;
use alloc::vec::Vec;
/// Reasons [`decode`] can reject its input.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub(crate) enum Base64Error {
    /// A non-ASCII character, or a byte outside the base64 alphabet, was found.
    InvalidCharacter,
    /// The length, the trailing `=` padding, or the final character's unused bits are malformed.
    InvalidPadding,
}
/// The standard base64 alphabet (RFC 4648 §4), indexed by 6-bit value.
const BASE64_ALPHABET: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `input` as standard base64 with `=` padding.
///
/// Each 3-byte group becomes 4 output characters. A trailing group of 1 or
/// 2 bytes is zero-extended, and the output positions that carry no input
/// bits are filled with `=`. An empty input yields an empty string.
pub(crate) fn encode(input: &[u8]) -> String {
    // The output is exactly 4 characters per (possibly partial) 3-byte group,
    // so reserve it once instead of letting the string grow repeatedly.
    let mut result = String::with_capacity(input.len().div_ceil(3) * 4);
    for chunk in input.chunks(3) {
        // `chunks` never yields an empty slice, so index 0 always exists.
        let b0 = chunk[0];
        let b1 = chunk.get(1).copied().unwrap_or(0);
        let b2 = chunk.get(2).copied().unwrap_or(0);
        let n = (u32::from(b0) << 16) | (u32::from(b1) << 8) | u32::from(b2);
        // Look up the 6-bit group starting `shift` bits from the bottom of `n`.
        let sextet = |shift: u32| BASE64_ALPHABET[((n >> shift) & 0x3F) as usize] as char;
        result.push(sextet(18));
        result.push(sextet(12));
        result.push(if chunk.len() > 1 { sextet(6) } else { '=' });
        result.push(if chunk.len() > 2 { sextet(0) } else { '=' });
    }
    result
}
/// Decodes standard (RFC 4648 §4) padded base64 into raw bytes.
///
/// ASCII space, tab, LF and CR are ignored anywhere in `input`. Empty or
/// whitespace-only input decodes to an empty vector. Otherwise decoding is
/// strict.
///
/// # Errors
///
/// * [`Base64Error::InvalidCharacter`] is returned in two cases:
///   * `input` contains a non-ASCII character.
///   * A byte outside the base64 alphabet appears. This includes an `=`
///     that is not part of the trailing padding.
/// * [`Base64Error::InvalidPadding`] is returned in three cases:
///   * The non-whitespace length is not a multiple of 4.
///   * There are more than two trailing `=`.
///   * The unused low bits of the final character are non-zero.
///
/// Length and padding are checked before the alphabet, so an input that is
/// malformed in both ways reports `InvalidPadding`.
pub(crate) fn decode(input: &str) -> Result<Vec<u8>, Base64Error> {
    // Strip whitespace and reject non-ASCII up front. Iterating bytes is
    // equivalent to iterating chars here: every byte of a multi-byte UTF-8
    // sequence is >= 0x80, so a non-ASCII char is caught on its first byte.
    let mut normalized: Vec<u8> = Vec::with_capacity(input.len());
    for b in input.bytes() {
        match b {
            b' ' | b'\t' | b'\n' | b'\r' => {}
            _ if !b.is_ascii() => return Err(Base64Error::InvalidCharacter),
            _ => normalized.push(b),
        }
    }
    if normalized.is_empty() {
        return Ok(Vec::new());
    }

    // Once the total length is a multiple of 4, one `=` always leaves a
    // 3-character final quantum and two leave a 2-character one. No separate
    // quantum-size check is needed.
    let pad_count = normalized.iter().rev().take_while(|&&b| b == b'=').count();
    if pad_count > 2 || !normalized.len().is_multiple_of(4) {
        return Err(Base64Error::InvalidPadding);
    }
    let data = &normalized[..normalized.len() - pad_count];

    let mut result = Vec::with_capacity(data.len() * 3 / 4);
    // `buf` holds the `bits` low-order bits not yet emitted (< 8 between chars).
    let mut buf: u32 = 0;
    let mut bits: u32 = 0;
    for &b in data {
        let val = sextet_value(b).ok_or(Base64Error::InvalidCharacter)?;
        buf = (buf << 6) | val;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            // `buf` has exactly `bits + 8` significant bits here, so this fits in a u8.
            result.push((buf >> bits) as u8);
            buf &= (1 << bits) - 1;
        }
    }
    // Any leftover bits come from the last character's unused low bits.
    // Canonical encodings (RFC 4648 §3.5) have them all zero.
    if buf != 0 {
        return Err(Base64Error::InvalidPadding);
    }
    Ok(result)
}

/// Maps one standard-alphabet base64 character to its 6-bit value.
///
/// Returns `None` if `b` is not in the alphabet (`=` included).
fn sextet_value(b: u8) -> Option<u32> {
    let v = match b {
        b'A'..=b'Z' => b - b'A',
        b'a'..=b'z' => b - b'a' + 26,
        b'0'..=b'9' => b - b'0' + 52,
        b'+' => 62,
        b'/' => 63,
        _ => return None,
    };
    Some(u32::from(v))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_basic() {
        assert_eq!(encode(b""), "");
        assert_eq!(encode(b"f"), "Zg==");
        assert_eq!(encode(b"fo"), "Zm8=");
        assert_eq!(encode(b"foo"), "Zm9v");
        assert_eq!(encode(b"foob"), "Zm9vYg==");
        assert_eq!(encode(b"fooba"), "Zm9vYmE=");
        assert_eq!(encode(b"foobar"), "Zm9vYmFy");
        assert_eq!(encode(b"user:password"), "dXNlcjpwYXNzd29yZA==");
    }

    #[test]
    fn decode_basic() {
        assert_eq!(decode("").unwrap(), b"");
        assert_eq!(decode("Zg==").unwrap(), b"f");
        assert_eq!(decode("Zm8=").unwrap(), b"fo");
        assert_eq!(decode("Zm9v").unwrap(), b"foo");
        assert_eq!(decode("Zm9vYg==").unwrap(), b"foob");
        assert_eq!(decode("Zm9vYmE=").unwrap(), b"fooba");
        assert_eq!(decode("Zm9vYmFy").unwrap(), b"foobar");
        assert_eq!(decode("dXNlcjpwYXNzd29yZA==").unwrap(), b"user:password");
    }

    #[test]
    fn decode_ignores_whitespace() {
        assert_eq!(decode("Zm9v\n").unwrap(), b"foo");
        assert_eq!(decode("Z m 9 v").unwrap(), b"foo");
        assert_eq!(decode("Zm9v\r\n").unwrap(), b"foo");
        assert_eq!(decode("Zm9v\t").unwrap(), b"foo");
        // Whitespace is stripped before padding is inspected.
        assert_eq!(decode("Zg= =").unwrap(), b"f");
    }

    #[test]
    fn decode_whitespace_only_is_empty() {
        assert_eq!(decode(" \r\n\t").unwrap(), b"");
    }

    #[test]
    fn decode_rejects_invalid_character() {
        assert_eq!(decode("Zm9*"), Err(Base64Error::InvalidCharacter));
        assert_eq!(decode("Z\u{00A0}m9v"), Err(Base64Error::InvalidCharacter));
    }

    #[test]
    fn decode_rejects_padding_before_end() {
        // `=` is only permitted as trailing padding; elsewhere it is not in the alphabet.
        assert_eq!(decode("Zg==Zg=="), Err(Base64Error::InvalidCharacter));
        assert_eq!(decode("Z=g="), Err(Base64Error::InvalidCharacter));
    }

    #[test]
    fn decode_rejects_unpadded_short_input() {
        assert_eq!(decode("A"), Err(Base64Error::InvalidPadding));
        assert_eq!(decode("Zg"), Err(Base64Error::InvalidPadding));
        assert_eq!(decode("Zm8"), Err(Base64Error::InvalidPadding));
    }

    #[test]
    fn decode_rejects_excess_padding() {
        assert_eq!(decode("Zg==="), Err(Base64Error::InvalidPadding));
        assert_eq!(decode("===="), Err(Base64Error::InvalidPadding));
    }

    #[test]
    fn decode_rejects_mismatched_padding() {
        assert_eq!(decode("Zg="), Err(Base64Error::InvalidPadding));
        assert_eq!(decode("Zm8=="), Err(Base64Error::InvalidPadding));
        assert_eq!(decode("Z=="), Err(Base64Error::InvalidPadding));
    }

    #[test]
    fn decode_rejects_non_zero_trailing_bits() {
        assert_eq!(decode("Zh=="), Err(Base64Error::InvalidPadding));
        // Same check with a single pad: '9' leaves a non-zero low bit where '8' would not.
        assert_eq!(decode("Zm9="), Err(Base64Error::InvalidPadding));
    }

    #[test]
    fn round_trip_all_byte_values() {
        let mut bytes = [0u8; 256];
        for (i, b) in bytes.iter_mut().enumerate() {
            *b = i as u8;
        }
        // Every prefix length exercises all three tail shapes (0, 1 or 2 pads).
        for len in 0..=bytes.len() {
            let chunk = &bytes[..len];
            let encoded = encode(chunk);
            assert_eq!(encoded.len() % 4, 0);
            assert_eq!(decode(&encoded).unwrap(), chunk);
        }
    }
}