use crate::{engine::Engine, DecodeError, DecodeSliceError, PAD_BYTE};
use std::{cmp, fmt, io};
/// Capacity, in bytes, of the buffer `DecoderReader` uses to hold base64 input that has been
/// read from the wrapped reader but not yet decoded.
pub(crate) const BUF_SIZE: usize = 1024;
/// Number of base64 input bytes in one complete chunk. `DecoderReader` keeps reading from the
/// wrapped reader until it has at least this many bytes buffered (or hits end of input).
const BASE64_CHUNK_SIZE: usize = 4;
/// Number of decoded bytes budgeted per base64 chunk; also the size of the buffer that holds
/// decoded bytes which did not fit into the caller's output buffer.
const DECODED_CHUNK_SIZE: usize = 3;
/// A `Read` implementation that decodes base64 data read from an underlying reader.
///
/// Base64 bytes are pulled from `inner` into `b64_buffer`, decoded with `engine`, and written
/// to the caller's buffer. When the caller's buffer is too small to take a whole decoded
/// chunk, the chunk is decoded into `decoded_chunk_buffer` and handed out over later reads.
pub struct DecoderReader<'e, E: Engine, R: io::Read> {
/// Engine used to decode the buffered base64 bytes.
engine: &'e E,
/// Wrapped reader that supplies the base64 input.
inner: R,
/// Base64 bytes read from `inner` that have not been decoded yet.
b64_buffer: [u8; BUF_SIZE],
/// Index in `b64_buffer` of the first byte that has not been decoded yet.
b64_offset: usize,
/// Number of not-yet-decoded bytes in `b64_buffer`, starting at `b64_offset`.
b64_len: usize,
/// Decoded bytes that did not fit into the caller's buffer on a previous read.
decoded_chunk_buffer: [u8; DECODED_CHUNK_SIZE],
/// Index in `decoded_chunk_buffer` of the first byte not yet handed to the caller.
decoded_offset: usize,
/// Number of bytes in `decoded_chunk_buffer`, starting at `decoded_offset`, still waiting
/// to be handed to the caller.
decoded_len: usize,
/// Total number of base64 input bytes decoded so far; added to chunk-relative offsets so
/// that reported error positions are relative to the start of the whole input.
input_consumed_len: usize,
/// Offset (relative to the start of the whole input) of the first padding byte seen so
/// far, or `None` if no padding has been decoded yet.
padding_offset: Option<usize>,
}
impl<'e, E: Engine, R: io::Read> fmt::Debug for DecoderReader<'e, E, R> {
    /// Formats the reader's bookkeeping state.
    ///
    /// The engine, the wrapped reader and the raw base64 buffer are not part of the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Bind exactly the fields that are printed; `..` skips everything else.
        let Self {
            b64_offset,
            b64_len,
            decoded_chunk_buffer,
            decoded_offset,
            decoded_len,
            input_consumed_len,
            padding_offset,
            ..
        } = self;

        f.debug_struct("DecoderReader")
            .field("b64_offset", b64_offset)
            .field("b64_len", b64_len)
            .field("decoded_chunk_buffer", decoded_chunk_buffer)
            .field("decoded_offset", decoded_offset)
            .field("decoded_len", decoded_len)
            .field("input_consumed_len", input_consumed_len)
            .field("padding_offset", padding_offset)
            .finish()
    }
}
impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> {
    /// Creates a decoder that pulls base64 input from `reader` and decodes it with `engine`.
    ///
    /// Nothing is read from `reader` here; all buffers start out empty.
    pub fn new(reader: R, engine: &'e E) -> Self {
        Self {
            engine,
            inner: reader,
            // base64 input side
            b64_buffer: [0; BUF_SIZE],
            b64_offset: 0,
            b64_len: 0,
            // decoded leftovers side
            decoded_chunk_buffer: [0; DECODED_CHUNK_SIZE],
            decoded_offset: 0,
            decoded_len: 0,
            // whole-stream bookkeeping
            input_consumed_len: 0,
            padding_offset: None,
        }
    }

    /// Copies as many pending decoded bytes as fit into `buf` and returns how many were copied.
    ///
    /// Must only be called when there are pending decoded bytes and `buf` is not empty
    /// (both are checked with debug assertions).
    fn flush_decoded_buf(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        debug_assert!(self.decoded_len > 0);
        debug_assert!(!buf.is_empty());

        let copy_len = cmp::min(self.decoded_len, buf.len());
        debug_assert!(copy_len > 0);
        debug_assert!(copy_len <= self.decoded_len);

        let pending_start = self.decoded_offset;
        let pending_end = pending_start + copy_len;
        let pending = &self.decoded_chunk_buffer[pending_start..pending_end];
        buf[..copy_len].copy_from_slice(pending);

        // Whatever was not copied stays queued for the next call.
        self.decoded_offset = pending_end;
        self.decoded_len -= copy_len;

        debug_assert!(self.decoded_len < DECODED_CHUNK_SIZE);

        Ok(copy_len)
    }

    /// Performs one read from the wrapped reader into the free tail of the base64 buffer.
    ///
    /// Returns the number of bytes read. The buffer must have free space at the end (checked
    /// with a debug assertion), so the wrapped reader is never handed an empty slice.
    fn read_from_delegate(&mut self) -> io::Result<usize> {
        debug_assert!(self.b64_offset + self.b64_len < BUF_SIZE);

        let free_tail = &mut self.b64_buffer[self.b64_offset + self.b64_len..];
        let bytes_read = self.inner.read(free_tail)?;
        self.b64_len += bytes_read;

        debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);

        Ok(bytes_read)
    }

    /// Decodes the next `b64_len_to_decode` buffered base64 bytes into `buf`.
    ///
    /// On success the decoded bytes are in the front of `buf`, the consumed input is removed
    /// from the pending base64 range, and the number of decoded bytes is returned.
    ///
    /// # Errors
    ///
    /// Returns an `io::ErrorKind::InvalidData` error wrapping a `DecodeError` when the engine
    /// rejects the input, or when bytes are decoded after padding was already seen by an
    /// earlier call. Offsets in the error are relative to the start of the whole input, not
    /// to this call. No state is modified when an error is returned.
    fn decode_to_buf(&mut self, b64_len_to_decode: usize, buf: &mut [u8]) -> io::Result<usize> {
        debug_assert!(self.b64_len >= b64_len_to_decode);
        debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
        debug_assert!(!buf.is_empty());

        // Snapshot the stream-level state as it was before this call.
        let consumed_before = self.input_consumed_len;
        let earlier_padding = self.padding_offset;

        let input = &self.b64_buffer[self.b64_offset..self.b64_offset + b64_len_to_decode];
        let estimate = self.engine.internal_decoded_len_estimate(b64_len_to_decode);

        let decode_metadata = match self.engine.internal_decode(input, buf, estimate) {
            Ok(metadata) => metadata,
            Err(DecodeSliceError::OutputSliceTooSmall) => {
                unreachable!("buf is sized correctly in calling code")
            }
            Err(DecodeSliceError::DecodeError(de)) => {
                // The engine reports offsets relative to `input`; shift them so they are
                // relative to everything decoded so far.
                let absolute = match de {
                    DecodeError::InvalidByte(offset, byte) => match (byte, earlier_padding) {
                        // A bad padding byte after padding was already seen is reported at
                        // the first padding byte.
                        (PAD_BYTE, Some(first_pad_offset)) => {
                            DecodeError::InvalidByte(first_pad_offset, PAD_BYTE)
                        }
                        _ => DecodeError::InvalidByte(consumed_before + offset, byte),
                    },
                    DecodeError::InvalidLength(len) => {
                        DecodeError::InvalidLength(consumed_before + len)
                    }
                    DecodeError::InvalidLastSymbol(offset, byte) => {
                        DecodeError::InvalidLastSymbol(consumed_before + offset, byte)
                    }
                    DecodeError::InvalidPadding => DecodeError::InvalidPadding,
                };
                return Err(io::Error::new(io::ErrorKind::InvalidData, absolute));
            }
        };

        // Padding was seen by an earlier call, yet this call still produced output:
        // report the error at that first padding byte.
        match earlier_padding {
            Some(first_pad_offset) if decode_metadata.decoded_len > 0 => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    DecodeError::InvalidByte(first_pad_offset, PAD_BYTE),
                ));
            }
            _ => {}
        }

        // Remember only the first padding position ever seen.
        let newly_found_padding = decode_metadata
            .padding_offset
            .map(|offset| consumed_before + offset);
        self.padding_offset = earlier_padding.or(newly_found_padding);

        self.input_consumed_len += b64_len_to_decode;
        self.b64_offset += b64_len_to_decode;
        self.b64_len -= b64_len_to_decode;

        debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);

        Ok(decode_metadata.decoded_len)
    }

    /// Consumes the decoder and returns the wrapped reader.
    ///
    /// Base64 bytes that were already read into the internal buffer but not yet decoded are
    /// dropped together with the decoder.
    pub fn into_inner(self) -> R {
        self.inner
    }
}
impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> {
    /// Decodes base64 from the wrapped reader into `buf`.
    ///
    /// Returns the number of decoded bytes written to the front of `buf`. `Ok(0)` is returned
    /// when `buf` is empty, or when the wrapped reader is at end of input and nothing is left
    /// to decode. Base64 input is buffered internally, so one call may pull more from the
    /// wrapped reader than it decodes.
    ///
    /// # Errors
    ///
    /// Errors from the wrapped reader are returned as-is. Input the engine rejects is
    /// reported as an `io::ErrorKind::InvalidData` error.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        // Invariants of the base64 input buffer.
        debug_assert!(self.b64_offset <= BUF_SIZE);
        debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
        debug_assert!(if self.b64_offset == BUF_SIZE {
            self.b64_len == 0
        } else {
            self.b64_len <= BUF_SIZE
        });

        // Invariants of the decoded leftovers buffer.
        debug_assert!(if self.decoded_len == 0 {
            self.decoded_offset <= DECODED_CHUNK_SIZE
        } else {
            self.decoded_offset < DECODED_CHUNK_SIZE
        });
        debug_assert!(self.decoded_len < DECODED_CHUNK_SIZE);
        debug_assert!(self.decoded_len + self.decoded_offset <= DECODED_CHUNK_SIZE);

        // Leftover decoded bytes go out first, before any more base64 is pulled in.
        if self.decoded_len > 0 {
            return self.flush_decoded_buf(buf);
        }

        // Buffer at least one full base64 chunk, unless the input ends first.
        let mut at_eof = false;
        while self.b64_len < BASE64_CHUNK_SIZE {
            // Slide the undecoded bytes to the front so the whole tail is free for reading.
            let undecoded = self.b64_offset..self.b64_offset + self.b64_len;
            self.b64_buffer.copy_within(undecoded, 0);
            self.b64_offset = 0;

            // The delegate is never given an empty slice, so a zero-length read means EOF.
            if self.read_from_delegate()? == 0 {
                at_eof = true;
                break;
            }
        }

        if self.b64_len == 0 {
            debug_assert!(at_eof);
            return Ok(0);
        }

        debug_assert!(if at_eof {
            self.b64_len > 0
        } else {
            self.b64_len >= BASE64_CHUNK_SIZE
        });
        debug_assert_eq!(0, self.decoded_len);

        if buf.len() >= DECODED_CHUNK_SIZE {
            // Decode straight into `buf`: as many chunks as both `buf` and the buffered
            // input allow.
            let b64_bytes_that_can_decode_into_buf = (buf.len() / DECODED_CHUNK_SIZE)
                .checked_mul(BASE64_CHUNK_SIZE)
                .expect("too many chunks");
            debug_assert!(b64_bytes_that_can_decode_into_buf >= BASE64_CHUNK_SIZE);

            let b64_bytes_available_to_decode = if at_eof {
                // Nothing more is coming, so a trailing partial chunk is decoded as well.
                self.b64_len
            } else {
                // Hold back a trailing partial chunk until the rest of it arrives.
                self.b64_len - self.b64_len % BASE64_CHUNK_SIZE
            };

            let actual_decode_len = cmp::min(
                b64_bytes_that_can_decode_into_buf,
                b64_bytes_available_to_decode,
            );
            return self.decode_to_buf(actual_decode_len, buf);
        }

        // `buf` cannot hold a full decoded chunk: decode a single chunk into a scratch
        // array, park it in `decoded_chunk_buffer`, and hand out as much as fits.
        let mut scratch = [0_u8; DECODED_CHUNK_SIZE];
        let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE);
        let decoded = self.decode_to_buf(to_decode, &mut scratch[..])?;

        self.decoded_chunk_buffer[..decoded].copy_from_slice(&scratch[..decoded]);
        self.decoded_offset = 0;
        self.decoded_len = decoded;
        debug_assert!(decoded <= 3);

        self.flush_decoded_buf(buf)
    }
}