use std::collections::VecDeque;
use std::io::Write;
use anyhow::{Context, Result};
use base64::Engine;
/// Number of base64 characters decoded per step.
///
/// Kept a multiple of 4 so every chunk is a whole number of 4-character base64 groups and
/// can be decoded independently of its neighbours; 3072 characters decode to 2304 bytes.
const B64_CHUNK: usize = 768 * 4;
/// Streaming base64 decoder that accepts text line by line and writes decoded bytes to an
/// output sink in fixed-size chunks.
///
/// The `Write` requirement lives on the `impl` block rather than here, following the Rust
/// API guidelines on struct bounds.
pub struct B64Decoder<W> {
    /// Base64 characters received but not yet decoded; between public calls it holds
    /// fewer than one full chunk.
    ring: VecDeque<u8>,
    /// Destination for decoded bytes.
    out: W,
}
impl<W: Write> B64Decoder<W> {
    /// Creates a decoder that writes decoded bytes to `out`.
    pub fn new(out: W) -> Self {
        Self {
            // `push_line` feeds at most one chunk on top of a partial chunk, so the buffer
            // never needs more than two chunks of room.
            ring: VecDeque::with_capacity(2 * B64_CHUNK),
            out,
        }
    }

    /// Appends one line of base64 text, decoding and writing every complete chunk.
    ///
    /// A single trailing `"\n"` or `"\r\n"` terminator is ignored, so the output of
    /// `BufRead::read_line` can be passed directly.
    ///
    /// Returns `Ok(false)`, without buffering anything, if the line contains a byte outside
    /// `A-Z`, `a-z`, `0-9`, `+`, `/`, `=`; returns `Ok(true)` once the line is accepted.
    ///
    /// # Errors
    ///
    /// Fails if a complete chunk is not valid base64 or if writing to the output fails.
    pub fn push_line(&mut self, line: &str) -> Result<bool> {
        let line = line
            .strip_suffix("\r\n")
            .or_else(|| line.strip_suffix('\n'))
            .unwrap_or(line);
        let is_b64 = |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'+' | b'/' | b'=');
        if !line.bytes().all(is_b64) {
            return Ok(false);
        }
        // Feed at most one chunk at a time so an arbitrarily long line never makes the
        // buffer hold more than two chunks' worth of undecoded text.
        for piece in line.as_bytes().chunks(B64_CHUNK) {
            self.ring.extend(piece);
            self.flush_chunks()?;
        }
        Ok(true)
    }

    /// Decodes and writes every complete `B64_CHUNK`-sized prefix of the buffer.
    ///
    /// `B64_CHUNK` is a multiple of 4, so each chunk is a whole number of base64 groups and
    /// decodes independently. A chunk is removed from the buffer only after it decodes.
    fn flush_chunks(&mut self) -> Result<()> {
        while self.ring.len() >= B64_CHUNK {
            // Decode straight from the deque's storage instead of copying into a new Vec.
            let chunk = &self.ring.make_contiguous()[..B64_CHUNK];
            let decoded = base64::engine::general_purpose::STANDARD
                .decode(chunk)
                .context("base64 chunk decode failed")?;
            self.ring.drain(..B64_CHUNK);
            self.out.write_all(&decoded).context("write failed")?;
        }
        Ok(())
    }

    /// Decodes whatever text remains buffered, flushes the writer, and returns it.
    ///
    /// # Errors
    ///
    /// Fails if the remaining text is not valid base64 (e.g. missing padding), or if
    /// writing or flushing the output fails. Flushing here surfaces errors that a buffered
    /// writer would otherwise swallow when dropped.
    pub fn finish(mut self) -> Result<W> {
        if !self.ring.is_empty() {
            let tail: &[u8] = self.ring.make_contiguous();
            let decoded = base64::engine::general_purpose::STANDARD
                .decode(tail)
                .context("base64 tail decode failed")?;
            self.out.write_all(&decoded).context("write failed")?;
        }
        self.out.flush().context("flush failed")?;
        Ok(self.out)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `lines` through a fresh decoder.
    ///
    /// Returns `(false, [])` as soon as a line is rejected, otherwise `(true, decoded)`.
    fn decode_lines(lines: &[&str]) -> Result<(bool, Vec<u8>)> {
        let mut dec = B64Decoder::new(Vec::new());
        for line in lines {
            if !dec.push_line(line)? {
                return Ok((false, vec![]));
            }
        }
        let out = dec.finish()?;
        Ok((true, out))
    }

    #[test]
    fn decodes_single_line() {
        let (ok, out) = decode_lines(&["aGVsbG8="]).unwrap();
        assert!(ok);
        assert_eq!(out, b"hello");
    }

    #[test]
    fn decodes_multiline() {
        let (ok, out) = decode_lines(&["aGVs", "bG8="]).unwrap();
        assert!(ok);
        assert_eq!(out, b"hello");
    }

    #[test]
    fn rejects_non_base64_character() {
        let (ok, _) = decode_lines(&["aGVs bG8="]).unwrap();
        assert!(!ok);
    }

    #[test]
    fn rejects_non_base64_on_later_line() {
        let (ok, _) = decode_lines(&["aGVs", "bG8= extra"]).unwrap();
        assert!(!ok);
    }

    #[test]
    fn decodes_exactly_one_chunk() {
        let line = "AAAA".repeat(B64_CHUNK / 4);
        let (ok, out) = decode_lines(&[&line]).unwrap();
        assert!(ok);
        // Every 4 base64 characters decode to 3 bytes.
        assert_eq!(out, vec![0u8; B64_CHUNK / 4 * 3]);
    }

    #[test]
    fn decodes_larger_than_one_chunk() {
        let line = "AAAA".repeat(B64_CHUNK / 4 * 2 + 10);
        let (ok, out) = decode_lines(&[&line]).unwrap();
        assert!(ok);
        assert_eq!(out.len(), (B64_CHUNK / 4 * 2 + 10) * 3);
    }

    #[test]
    fn ring_buffer_stays_bounded() {
        let single = "AAAA".repeat(100);
        let line_count = (B64_CHUNK / 400) * 4;
        let mut dec = B64Decoder::new(Vec::new());
        for _ in 0..line_count {
            assert!(dec.push_line(&single).unwrap());
            // Every complete chunk must already have been decoded and written out.
            assert!(dec.ring.len() < B64_CHUNK);
        }
        let out = dec.finish().unwrap();
        assert_eq!(out.len(), line_count * 100 * 3);
    }

    #[test]
    fn empty_input_returns_empty_output() {
        let dec = B64Decoder::new(Vec::new());
        let out = dec.finish().unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn invalid_base64_padding_returns_error() {
        let mut dec = B64Decoder::new(Vec::new());
        assert!(dec.push_line("aGVs").unwrap());
        assert!(dec.push_line("bG8").unwrap());
        assert!(dec.finish().is_err());
    }
}