use crate::Result;
/// The standard base64 alphabet (RFC 4648 §4): the byte at index `i` encodes the 6-bit value `i`.
const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Marker stored in [`LOOKUP`] for bytes that are not part of [`ALPHABET`].
const INVALID: u8 = 0xFF;

/// Reverse mapping from an input byte to its 6-bit value, or [`INVALID`].
///
/// Evaluated at compile time so [`decode`] does not rebuild the table on every call.
const LOOKUP: [u8; 256] = {
    let mut table = [INVALID; 256];
    let mut i = 0;
    while i < ALPHABET.len() {
        // `i < 64`, so the cast to `u8` is lossless.
        table[ALPHABET[i] as usize] = i as u8;
        i += 1;
    }
    table
};

/// Decodes standard, padded base64 (RFC 4648 §4) into raw bytes.
///
/// ASCII space, tab, CR and LF are skipped anywhere in the input, so line-wrapped
/// payloads decode directly. After whitespace is removed, the input must be a whole
/// number of 4-character groups, so unpadded input is rejected. `=` may appear only at
/// the very end (at most two of them, optionally separated by whitespace).
///
/// Decoding is lenient about unused low bits in the last character before padding:
/// `"Zh=="` decodes to `b"f"`, just like the canonical `"Zg=="`.
///
/// # Errors
///
/// Returns [`crate::Error::InvalidImage`] if the input contains a byte outside the
/// alphabet, more than two `=`, a non-padding byte after an `=`, or a length
/// (excluding whitespace) that is not a multiple of 4.
pub fn decode(input: &str) -> Result<Vec<u8>> {
    // Every 4 input characters yield at most 3 bytes. Dividing before multiplying keeps
    // the estimate from overflowing `usize` on very large inputs.
    let mut out = Vec::with_capacity(input.len() / 4 * 3);
    // Sextets of the current 4-character group, packed MSB-first into the low 24 bits.
    let mut acc: u32 = 0;
    // Number of sextets accumulated in the current group (always 0..=3 between bytes).
    let mut sextets: u32 = 0;
    // Number of `=` seen so far. Once non-zero, only `=` or whitespace may follow.
    let mut pad: usize = 0;

    for &c in input.as_bytes() {
        let value = match c {
            b' ' | b'\t' | b'\r' | b'\n' => continue,
            b'=' => {
                pad += 1;
                if pad > 2 {
                    return Err(crate::Error::InvalidImage(
                        "dmg: base64 has more than two '=' padding bytes".into(),
                    ));
                }
                // Padding contributes zero bits. The bytes it produces are dropped below.
                0
            }
            _ => {
                if pad > 0 {
                    return Err(crate::Error::InvalidImage(
                        "dmg: base64 has non-padding bytes after '='".into(),
                    ));
                }
                match LOOKUP[usize::from(c)] {
                    INVALID => {
                        return Err(crate::Error::InvalidImage(format!(
                            "dmg: base64 contains invalid byte {c:#x}"
                        )));
                    }
                    v => u32::from(v),
                }
            }
        };

        acc = (acc << 6) | value;
        sextets += 1;
        if sextets == 4 {
            // `acc` holds exactly 24 bits, so the most significant byte is always zero.
            let [_, b0, b1, b2] = acc.to_be_bytes();
            out.extend_from_slice(&[b0, b1, b2]);
            acc = 0;
            sextets = 0;
        }
    }

    if sextets != 0 {
        return Err(crate::Error::InvalidImage(
            "dmg: base64 input length not a multiple of 4 after stripping whitespace".into(),
        ));
    }

    // Each `=` stands for one byte that was never encoded. Padding always completes a
    // group, so `out` holds at least `pad` bytes here. `saturating_sub` makes that
    // explicit rather than relying on it.
    out.truncate(out.len().saturating_sub(pad));
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// RFC 4648 §10 test vectors for the standard alphabet.
    #[test]
    fn decodes_known_vectors() {
        assert_eq!(decode("").unwrap(), b"");
        assert_eq!(decode("Zg==").unwrap(), b"f");
        assert_eq!(decode("Zm8=").unwrap(), b"fo");
        assert_eq!(decode("Zm9v").unwrap(), b"foo");
        assert_eq!(decode("Zm9vYg==").unwrap(), b"foob");
        assert_eq!(decode("Zm9vYmE=").unwrap(), b"fooba");
        assert_eq!(decode("Zm9vYmFy").unwrap(), b"foobar");
    }

    #[test]
    fn tolerates_whitespace_and_line_breaks() {
        let s = "Zm9v\n\tYmFy\r\n";
        assert_eq!(decode(s).unwrap(), b"foobar");
        assert_eq!(decode("Zm9v YmFy").unwrap(), b"foobar");
        // Whitespace is skipped even between padding characters.
        assert_eq!(decode("Zg= =").unwrap(), b"f");
    }

    #[test]
    fn rejects_invalid_bytes() {
        assert!(decode("Zm9v!").is_err());
        // URL-safe alphabet characters are not part of the standard alphabet.
        assert!(decode("Zm-_").is_err());
        // Non-ASCII input bytes are never valid.
        assert!(decode("Zm9v\u{e9}").is_err());
    }

    #[test]
    fn rejects_bad_padding_and_length() {
        assert!(decode("Zm===").is_err());
        assert!(decode("Zm9vA").is_err());
        // Unpadded input is not accepted.
        assert!(decode("Zg").is_err());
        assert!(decode("Zg=").is_err());
        assert!(decode("=").is_err());
        // Nothing but padding may follow the first '='.
        assert!(decode("Zg=A").is_err());
        assert!(decode("Zg==Zg==").is_err());
    }

    /// Pins the current lenient behaviour: unused low bits before padding are ignored.
    #[test]
    fn accepts_noncanonical_trailing_bits() {
        assert_eq!(decode("Zh==").unwrap(), b"f");
    }
}