use crate::{decode_config_slice, Config, DecodeError};
use std::io::Read;
use std::{cmp, fmt, io};
/// Capacity, in bytes, of the internal buffer holding base64 input read from the delegate.
pub(crate) const BUF_SIZE: usize = 1024;
/// Number of base64 symbols in one complete encoded chunk.
const BASE64_CHUNK_SIZE: usize = 4;
/// Number of bytes one complete, unpadded base64 chunk decodes to.
const DECODED_CHUNK_SIZE: usize = 3;

/// An [`io::Read`] adapter that decodes base64 text pulled from a delegate reader.
///
/// Up to `BUF_SIZE` bytes of base64 input are buffered internally. Decoded bytes that do not
/// fit into a caller's buffer (at most one chunk's worth) are held back and handed out by the
/// next `read` call.
pub struct DecoderReader<'a, R: 'a + io::Read> {
    /// Decoding configuration passed to the decoder for every chunk.
    config: Config,
    /// The underlying reader supplying base64 text.
    r: &'a mut R,
    /// Storage for base64 bytes read from `r`; the not-yet-decoded bytes live in
    /// `b64_buffer[b64_offset..b64_offset + b64_len]`.
    b64_buffer: [u8; BUF_SIZE],
    /// Index of the first not-yet-decoded byte in `b64_buffer`.
    b64_offset: usize,
    /// Number of not-yet-decoded bytes starting at `b64_offset`.
    b64_len: usize,
    /// One decoded chunk staged for a caller whose buffer was too small to take it whole; the
    /// bytes not yet returned live in `decoded_buffer[decoded_offset..decoded_offset + decoded_len]`.
    decoded_buffer: [u8; DECODED_CHUNK_SIZE],
    /// Index of the first not-yet-returned byte in `decoded_buffer`.
    decoded_offset: usize,
    /// Number of not-yet-returned bytes starting at `decoded_offset`.
    decoded_len: usize,
    /// Total number of base64 bytes decoded so far. It is added to error offsets so they point
    /// into the whole input stream rather than into the current buffer window.
    total_b64_decoded: usize,
}
impl<'a, R: io::Read> fmt::Debug for DecoderReader<'a, R> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("DecoderReader")
.field("config", &self.config)
.field("b64_offset", &self.b64_offset)
.field("b64_len", &self.b64_len)
.field("decoded_buffer", &self.decoded_buffer)
.field("decoded_offset", &self.decoded_offset)
.field("decoded_len", &self.decoded_len)
.field("total_b64_decoded", &self.total_b64_decoded)
.finish()
}
}
impl<'a, R: io::Read> DecoderReader<'a, R> {
    /// Creates a decoder that reads base64 text from `r` and decodes it according to `config`.
    pub fn new(r: &'a mut R, config: Config) -> Self {
        DecoderReader {
            config,
            r,
            b64_buffer: [0; BUF_SIZE],
            b64_offset: 0,
            b64_len: 0,
            decoded_buffer: [0; DECODED_CHUNK_SIZE],
            decoded_offset: 0,
            decoded_len: 0,
            total_b64_decoded: 0,
        }
    }

    /// Copies as many staged decoded bytes as fit into `buf`.
    ///
    /// Must only be called when `decoded_len > 0` and `buf` is non-empty. Returns the number of
    /// bytes copied, which is therefore at least 1.
    fn flush_decoded_buf(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        debug_assert!(self.decoded_len > 0);
        debug_assert!(!buf.is_empty());

        let copy_len = cmp::min(self.decoded_len, buf.len());
        debug_assert!(copy_len > 0);
        debug_assert!(copy_len <= self.decoded_len);

        buf[..copy_len].copy_from_slice(
            &self.decoded_buffer[self.decoded_offset..self.decoded_offset + copy_len],
        );

        self.decoded_offset += copy_len;
        self.decoded_len -= copy_len;

        debug_assert!(self.decoded_len < DECODED_CHUNK_SIZE);

        Ok(copy_len)
    }

    /// Reads more base64 bytes from the delegate into the free tail of `b64_buffer`.
    ///
    /// Returns the number of bytes read. The slice handed to the delegate is never empty (see the
    /// debug assertion), so `0` means the delegate is at EOF. Delegate errors are propagated
    /// unchanged and leave the buffered state untouched.
    fn read_from_delegate(&mut self) -> io::Result<usize> {
        debug_assert!(self.b64_offset + self.b64_len < BUF_SIZE);

        let read = self
            .r
            .read(&mut self.b64_buffer[self.b64_offset + self.b64_len..])?;
        self.b64_len += read;

        debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);

        Ok(read)
    }

    /// Decodes the next `num_bytes` buffered base64 bytes into `buf` and marks them consumed.
    ///
    /// Returns the number of decoded bytes written to `buf`.
    ///
    /// # Errors
    ///
    /// A decoder `DecodeError` is wrapped in an `io::Error` of kind `InvalidData`. Byte offsets
    /// in it are shifted by `total_b64_decoded`, so they index the whole base64 stream rather
    /// than just this slice. On error no buffered state is advanced.
    fn decode_to_buf(&mut self, num_bytes: usize, buf: &mut [u8]) -> io::Result<usize> {
        debug_assert!(self.b64_len >= num_bytes);
        debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
        debug_assert!(!buf.is_empty());

        let decoded = decode_config_slice(
            &self.b64_buffer[self.b64_offset..self.b64_offset + num_bytes],
            self.config,
            &mut buf[..],
        )
        .map_err(|e| {
            // Rebase slice-relative offsets onto the start of the overall stream.
            let e = match e {
                DecodeError::InvalidByte(offset, byte) => {
                    DecodeError::InvalidByte(self.total_b64_decoded + offset, byte)
                }
                DecodeError::InvalidLength => DecodeError::InvalidLength,
                DecodeError::InvalidLastSymbol(offset, byte) => {
                    DecodeError::InvalidLastSymbol(self.total_b64_decoded + offset, byte)
                }
            };
            io::Error::new(io::ErrorKind::InvalidData, e)
        })?;

        self.total_b64_decoded += num_bytes;
        self.b64_offset += num_bytes;
        self.b64_len -= num_bytes;

        debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);

        Ok(decoded)
    }
}
impl<'a, R: Read> Read for DecoderReader<'a, R> {
    /// Decodes base64 data from the delegate reader into `buf`.
    ///
    /// Returns the number of decoded bytes written, which may be fewer than `buf.len()` even
    /// before EOF. `Ok(0)` is returned when `buf` is empty, or once the delegate reports EOF
    /// with no buffered base64 left.
    ///
    /// # Errors
    ///
    /// Errors from the delegate reader are propagated. Malformed base64 yields an
    /// `io::ErrorKind::InvalidData` error wrapping the decoder's error.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        // b64_offset may equal BUF_SIZE after the whole buffer was consumed last time.
        debug_assert!(self.b64_offset <= BUF_SIZE);
        debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
        debug_assert!(if self.b64_offset == BUF_SIZE {
            self.b64_len == 0
        } else {
            self.b64_len <= BUF_SIZE
        });

        // decoded_offset may equal DECODED_CHUNK_SIZE once a staged chunk was fully handed out.
        debug_assert!(if self.decoded_len == 0 {
            self.decoded_offset <= DECODED_CHUNK_SIZE
        } else {
            self.decoded_offset < DECODED_CHUNK_SIZE
        });
        // A chunk is staged only right before at least one byte of it is copied out, so a
        // full DECODED_CHUNK_SIZE bytes are never left behind between calls.
        debug_assert!(self.decoded_len < DECODED_CHUNK_SIZE);
        debug_assert!(self.decoded_len + self.decoded_offset <= DECODED_CHUNK_SIZE);

        // Hand out leftovers from an earlier short read before consuming more base64.
        if self.decoded_len > 0 {
            return self.flush_decoded_buf(buf);
        }

        // Ensure at least one complete chunk is buffered, unless the delegate hits EOF first.
        let mut at_eof = false;
        while self.b64_len < BASE64_CHUNK_SIZE {
            // Shift the (< BASE64_CHUNK_SIZE) pending bytes to the front of the buffer through a
            // tiny temporary so the delegate can fill the rest.
            // NOTE(review): `<[u8]>::copy_within` (Rust 1.37+) does this in one call; the manual
            // copy is kept, presumably for older-toolchain support — confirm the crate's MSRV.
            let mut memmove_buf = [0_u8; BASE64_CHUNK_SIZE];
            memmove_buf[..self.b64_len].copy_from_slice(
                &self.b64_buffer[self.b64_offset..self.b64_offset + self.b64_len],
            );
            self.b64_buffer[0..self.b64_len].copy_from_slice(&memmove_buf[..self.b64_len]);
            self.b64_offset = 0;

            // The slice handed to the delegate is never empty, so a 0-byte read means EOF.
            if self.read_from_delegate()? == 0 {
                at_eof = true;
                break;
            }
        }

        if self.b64_len == 0 {
            // At EOF with nothing left to decode.
            debug_assert!(at_eof);
            return Ok(0);
        }

        debug_assert!(if at_eof {
            // At EOF the final chunk may be incomplete.
            self.b64_len > 0
        } else {
            self.b64_len >= BASE64_CHUNK_SIZE
        });
        debug_assert_eq!(0, self.decoded_len);

        if buf.len() < DECODED_CHUNK_SIZE {
            // `buf` cannot hold a whole decoded chunk. Decode one chunk into a temporary, stage
            // it in `decoded_buffer`, and return what fits; later calls flush the remainder.
            // The temporary is needed because `decode_to_buf` already borrows `self` mutably.
            let mut decoded_chunk = [0_u8; DECODED_CHUNK_SIZE];
            // At EOF fewer than BASE64_CHUNK_SIZE bytes may remain; they are decoded as the
            // final, partial chunk.
            let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE);
            let decoded = self.decode_to_buf(to_decode, &mut decoded_chunk[..])?;
            self.decoded_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]);
            self.decoded_offset = 0;
            self.decoded_len = decoded;

            // May be fewer than a full chunk, e.g. for a padded or partial final chunk.
            debug_assert!(decoded <= DECODED_CHUNK_SIZE);

            self.flush_decoded_buf(buf)
        } else {
            // Decode straight into `buf`. Every BASE64_CHUNK_SIZE input bytes need
            // DECODED_CHUNK_SIZE bytes of room.
            let b64_bytes_that_can_decode_into_buf = (buf.len() / DECODED_CHUNK_SIZE)
                .checked_mul(BASE64_CHUNK_SIZE)
                .expect("too many chunks");
            debug_assert!(b64_bytes_that_can_decode_into_buf >= BASE64_CHUNK_SIZE);

            // Before EOF, decode whole chunks only, because a trailing partial chunk may be
            // completed by the next delegate read. At EOF, whatever remains is the final chunk.
            let b64_bytes_available_to_decode = if at_eof {
                self.b64_len
            } else {
                self.b64_len - self.b64_len % BASE64_CHUNK_SIZE
            };

            let actual_decode_len = cmp::min(
                b64_bytes_that_can_decode_into_buf,
                b64_bytes_available_to_decode,
            );
            self.decode_to_buf(actual_decode_len, buf)
        }
    }
}