use std::io::{BufRead, Read, Take};
use crate::{
header::{HeaderMap, HeaderParser},
io::BufReadMoreExt,
nomutil::NomParseError,
};
use super::HTTPError;
/// Incremental decoder for a body sent with `Transfer-Encoding: chunked`.
///
/// The decoder is driven one chunk at a time:
///
/// 1. [`begin_chunk`](Self::begin_chunk) reads and parses the chunk-size line,
/// 2. [`read_data`](Self::read_data) returns a reader limited to the chunk's data,
/// 3. [`end_chunk`](Self::end_chunk) checks that all of the data was consumed and,
///    for a non-empty chunk, consumes the line break that follows it.
///
/// Once the zero-length chunk has been ended, [`read_trailer`](Self::read_trailer)
/// parses the trailer section. Each step asserts that the decoder is in the
/// expected state, so calling the steps out of order panics.
pub struct ChunkedDecoder<R: BufRead> {
    /// The underlying stream while no chunk data reader is active. Exactly one of
    /// `stream` and `data_reader` owns the stream at any time.
    stream: Option<R>,
    /// Length-limited wrapper around the stream, present only while a chunk's
    /// data is being read.
    data_reader: Option<ChunkDataReader<R>>,
    state: DecoderState,
    /// Scratch space for chunk-size lines, chunk delimiters and the trailer.
    buffer: Vec<u8>,
    /// Upper bound passed to `read_until_boundary` when reading the trailer.
    buffer_limit: u64,
    /// Length declared by the most recently parsed chunk-size line.
    chunk_length: u64,
}

impl<R> ChunkedDecoder<R>
where
    R: BufRead,
{
    /// Creates a decoder that expects `stream` to start with a chunk-size line.
    ///
    /// The trailer section is limited to 32768 bytes.
    pub fn new(stream: R) -> Self {
        Self {
            stream: Some(stream),
            data_reader: None,
            state: DecoderState::StartOfLine,
            buffer: Vec::new(),
            buffer_limit: 32768,
            chunk_length: 0,
        }
    }

    /// Returns a shared reference to the underlying stream.
    pub fn get_ref(&self) -> &R {
        match (&self.stream, &self.data_reader) {
            (Some(stream), _) => stream,
            (None, Some(reader)) => reader.stream.get_ref(),
            (None, None) => unreachable!("chunked decoder lost its underlying stream"),
        }
    }

    /// Returns a mutable reference to the underlying stream.
    ///
    /// Reading from it directly will desynchronize the decoder.
    pub fn get_mut(&mut self) -> &mut R {
        match (&mut self.stream, &mut self.data_reader) {
            (Some(stream), _) => stream,
            (None, Some(reader)) => reader.stream.get_mut(),
            (None, None) => unreachable!("chunked decoder lost its underlying stream"),
        }
    }

    /// Consumes the decoder and returns the underlying stream.
    pub fn into_inner(self) -> R {
        match (self.stream, self.data_reader) {
            (Some(stream), _) => stream,
            (None, Some(reader)) => reader.stream.into_inner(),
            (None, None) => unreachable!("chunked decoder lost its underlying stream"),
        }
    }

    /// Reads the next chunk-size line and returns its length and extensions.
    ///
    /// # Errors
    /// Returns an error if reading the line (limited to 4096 bytes) fails or if
    /// the line does not parse as chunk metadata.
    ///
    /// # Panics
    /// If the decoder is not positioned at the start of a chunk-size line.
    pub fn begin_chunk(&mut self) -> Result<ChunkMetadata, HTTPError> {
        tracing::trace!("begin_chunk");
        assert_eq!(self.state, DecoderState::StartOfLine);
        self.buffer.clear();
        self.stream
            .as_mut()
            .unwrap()
            .read_limit_until(b'\n', &mut self.buffer, 4096)?;
        let metadata = parse_chunk_metadata(&self.buffer)?;
        self.chunk_length = metadata.length;
        self.state = DecoderState::EndOfLine;
        Ok(metadata)
    }

    /// Returns a reader over the current chunk's data, creating it on first call.
    ///
    /// # Panics
    /// If no chunk-size line has just been parsed and no chunk data is being read.
    pub fn read_data(&mut self) -> &mut ChunkDataReader<R> {
        if self.stream.is_some() {
            self.set_up_chunk_data_reader();
        }
        self.data_reader
            .as_mut()
            .expect("data reader exists while chunk data is being read")
    }

    /// Moves the stream into a [`ChunkDataReader`] limited to `chunk_length` bytes.
    fn set_up_chunk_data_reader(&mut self) {
        tracing::trace!(chunk_length = self.chunk_length, "set_up_chunk_data_reader");
        assert_eq!(self.state, DecoderState::EndOfLine);
        self.state = DecoderState::InBody;
        let stream = self.stream.take().unwrap().take(self.chunk_length);
        self.data_reader = Some(ChunkDataReader {
            stream,
            amount_read: 0,
        });
    }

    /// Finishes the current chunk.
    ///
    /// For a non-empty chunk, the line break after the data is consumed and the
    /// decoder moves on to the next chunk-size line. After the zero-length chunk,
    /// the decoder expects the trailer.
    ///
    /// # Errors
    /// [`HTTPError::UnexpectedEnd`] if the chunk's data was not fully read. In that
    /// case the decoder is left unchanged, so the data reader and the underlying
    /// stream remain accessible. An error is also returned if the line break after
    /// the data cannot be read or is malformed.
    ///
    /// # Panics
    /// If no chunk data reader is active.
    pub fn end_chunk(&mut self) -> Result<(), HTTPError> {
        tracing::trace!("end_chunk");
        assert_eq!(self.state, DecoderState::InBody);
        let amount_read = self
            .data_reader
            .as_ref()
            .expect("data reader exists in the InBody state")
            .amount_read;
        if amount_read != self.chunk_length {
            // Bail out before taking the reader: dropping it here would also drop
            // the underlying stream and make `get_ref`/`into_inner` panic.
            return Err(HTTPError::UnexpectedEnd);
        }
        let data_reader = self.data_reader.take().expect("checked above");
        self.stream = Some(data_reader.stream.into_inner());
        if self.chunk_length == 0 {
            // The line break after a zero-length chunk belongs to the trailer section.
            self.state = DecoderState::StartOfTrailer;
        } else {
            self.read_chunk_deliminator()?;
            self.state = DecoderState::StartOfLine;
        }
        Ok(())
    }

    /// Consumes and validates the line break that terminates a non-empty chunk's data.
    fn read_chunk_deliminator(&mut self) -> Result<(), HTTPError> {
        tracing::trace!("read_chunk_deliminator");
        self.buffer.clear();
        self.stream
            .as_mut()
            .unwrap()
            .read_limit_until(b'\n', &mut self.buffer, 2)?;
        // Chunk data must be followed by CRLF (a bare LF is tolerated). Anything
        // else means the declared chunk size does not match the actual framing.
        // NOTE(review): assumes `read_limit_until` keeps the delimiter in the
        // buffer like `BufRead::read_until`; the chunk-size lines handed to
        // `parse_chunk_metadata` are expected to end in `\n` for the same reason.
        match self.buffer.as_slice() {
            b"\r\n" | b"\n" => Ok(()),
            _ => Err(HTTPError::InvalidEncoding { source: None }),
        }
    }

    /// Reads and parses the trailer section that follows the zero-length chunk.
    ///
    /// # Errors
    /// Returns an error if reading the trailer (bounded by `buffer_limit`) fails or
    /// if its header fields do not parse.
    ///
    /// # Panics
    /// If the zero-length chunk has not been ended yet.
    pub fn read_trailer(&mut self) -> Result<HeaderMap, HTTPError> {
        tracing::trace!("read_trailer");
        assert_eq!(self.state, DecoderState::StartOfTrailer);
        self.buffer.clear();
        let stream = self.stream.as_mut().unwrap();
        crate::header::read_until_boundary(stream, &mut self.buffer, self.buffer_limit)?;
        let parser = HeaderParser::new();
        let header_map = parser
            .parse_header(crate::stringutil::trim_trailing_crlf(&self.buffer))
            .map_err(|error| HTTPError::InvalidEncoding {
                source: Some(Box::new(error)),
            })?;
        self.state = DecoderState::EndOfTrailer;
        Ok(header_map)
    }
}
/// Where a [`ChunkedDecoder`] currently is within the chunked body.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum DecoderState {
    /// Waiting to read a chunk-size line.
    StartOfLine,
    /// A chunk-size line was parsed; the data reader is not yet set up.
    EndOfLine,
    /// The chunk data reader is active.
    InBody,
    /// The zero-length chunk was ended; the trailer has not been read.
    StartOfTrailer,
    /// The trailer was read; the body is fully decoded.
    EndOfTrailer,
}
/// Reader over the data of a single chunk, obtained from [`ChunkedDecoder::read_data`].
///
/// It yields at most the number of bytes declared by the chunk-size line. It also
/// tracks how many bytes the caller has consumed, through either [`Read`] or
/// [`BufRead`], so that [`ChunkedDecoder::end_chunk`] can detect a short read.
pub struct ChunkDataReader<R: BufRead> {
    /// The underlying stream, limited to the current chunk's length.
    stream: Take<R>,
    /// Bytes consumed so far via `read` or `consume`.
    amount_read: u64,
}

impl<R: BufRead> Read for ChunkDataReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let amount = self.stream.read(buf)?;
        self.amount_read += amount as u64;
        Ok(amount)
    }
}

impl<R: BufRead> BufRead for ChunkDataReader<R> {
    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
        self.stream.fill_buf()
    }

    fn consume(&mut self, amt: usize) {
        // Bytes consumed through the BufRead path must be counted too. Otherwise
        // `read_until`/`read_line` users would make `end_chunk` report a short read.
        // `Take::consume` clamps `amt` to the remaining limit, so measure how much
        // the limit actually dropped.
        let before = self.stream.limit();
        self.stream.consume(amt);
        self.amount_read += before - self.stream.limit();
    }
}
/// Parsed chunk-size line: the chunk's data length and its chunk extensions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChunkMetadata {
    /// Number of data bytes in the chunk; `0` marks the last chunk.
    pub length: u64,
    /// Chunk extensions as `(name, value)` pairs.
    pub parameters: Vec<(String, String)>,
}
pub fn parse_chunk_metadata(line: &[u8]) -> Result<ChunkMetadata, HTTPError> {
if let Ok(result) = super::pc::parse_chunk_metadata(line) {
return Ok(ChunkMetadata {
length: result.0,
parameters: result.1,
});
};
match super::pc::parse_chunk_metadata_fallback(line) {
Ok(size) => Ok(ChunkMetadata {
length: size,
parameters: Vec::new(),
}),
Err(error) => Err(HTTPError::InvalidEncoding {
source: Some(Box::new(NomParseError::from_nom(line, &error))),
}),
}
}
/// [`Read`] adapter that yields the decoded data of a chunked body.
///
/// Chunk boundaries are crossed transparently. When the zero-length chunk is
/// reached, the trailer is read and discarded, and every later read returns 0.
pub struct ChunkedReader<R: BufRead> {
    state: ChunkedReaderState,
    inner: ChunkedDecoder<R>,
    /// Declared length of the chunk currently being read.
    chunk_size: u64,
    /// Bytes of the current chunk handed to the caller so far.
    chunk_amount_read: u64,
}

impl<R: BufRead> ChunkedReader<R> {
    /// Wraps `stream`, which is expected to begin with a chunk-size line.
    pub fn new(stream: R) -> Self {
        let inner = ChunkedDecoder::new(stream);
        Self {
            state: ChunkedReaderState::Start,
            inner,
            chunk_size: 0,
            chunk_amount_read: 0,
        }
    }

    /// Shared access to the underlying stream.
    pub fn get_ref(&self) -> &R {
        self.inner.get_ref()
    }

    /// Mutable access to the underlying stream.
    pub fn get_mut(&mut self) -> &mut R {
        self.inner.get_mut()
    }

    /// Gives back the underlying stream, consuming the reader.
    pub fn into_inner(self) -> R {
        self.inner.into_inner()
    }

    /// Wraps a decoder error so it can be returned through the `std::io` API.
    fn remap_error(error: HTTPError) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::Other, error)
    }

    /// Parses the next chunk-size line and resets the per-chunk counter.
    fn read_metadata(&mut self) -> std::io::Result<()> {
        tracing::trace!("read_metadata");
        self.chunk_size = self.inner.begin_chunk().map_err(Self::remap_error)?.length;
        self.chunk_amount_read = 0;
        Ok(())
    }

    /// Finishes the zero-length chunk and reads (and discards) the trailer.
    fn read_0_chunk_and_trailer(&mut self) -> std::io::Result<()> {
        tracing::trace!("read_0_chunk_and_trailer");
        // Obtaining the (empty) data reader moves the decoder into the state that
        // `end_chunk` expects; the read itself yields nothing for a zero-length chunk.
        let mut scratch = [0u8; 1];
        let _ignored = self.inner.read_data().read(&mut scratch)?;
        self.inner.end_chunk().map_err(Self::remap_error)?;
        self.inner.read_trailer().map_err(Self::remap_error)?;
        Ok(())
    }
}
impl<R: BufRead> Read for ChunkedReader<R> {
    /// Reads decoded body bytes, crossing chunk boundaries as needed.
    ///
    /// Returns `Ok(0)` only once the zero-length chunk and the trailer have been
    /// consumed (or when `buf` is empty). If the underlying stream ends in the
    /// middle of a chunk, an [`std::io::ErrorKind::UnexpectedEof`] error is
    /// returned, so a truncated body is never mistaken for a complete one.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() || self.state == ChunkedReaderState::Finished {
            return Ok(0);
        }

        loop {
            match self.state {
                ChunkedReaderState::Finished => return Ok(0),
                ChunkedReaderState::Start => {
                    self.read_metadata()?;
                    if self.chunk_size == 0 {
                        self.read_0_chunk_and_trailer()?;
                        tracing::trace!("new state Finished");
                        self.state = ChunkedReaderState::Finished;
                        return Ok(0);
                    }
                    tracing::trace!("new state ReadingData");
                    self.state = ChunkedReaderState::ReadingData;
                }
                ChunkedReaderState::ReadingData => {
                    let amount = self.inner.read_data().read(buf)?;
                    self.chunk_amount_read += amount as u64;
                    if amount > 0 {
                        return Ok(amount);
                    }
                    // `buf` is non-empty, so a zero-length read means either the
                    // chunk is exhausted or the underlying stream hit EOF early.
                    if self.chunk_amount_read != self.chunk_size {
                        return Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "chunked body ended before the declared chunk size was read",
                        ));
                    }
                    tracing::trace!("new state Start");
                    self.inner.end_chunk().map_err(Self::remap_error)?;
                    self.state = ChunkedReaderState::Start;
                }
            }
        }
    }
}
/// How far a [`ChunkedReader`] has progressed through the body.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ChunkedReaderState {
    /// The next read must parse a chunk-size line.
    Start,
    /// Data of the current chunk is being returned.
    ReadingData,
    /// The zero-length chunk and the trailer were consumed; reads return 0.
    Finished,
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    /// Chunks "abc", "hello" and the terminating empty chunk, followed by a trailer.
    const BODY: &[u8] = b"3\r\nabc\r\n5\r\nhello\r\n0\r\nk1:v2\r\n\r\n";

    #[test]
    fn test_decoder() {
        let mut decoder = ChunkedDecoder::new(Cursor::new(BODY));
        let expected_chunks: [&[u8]; 3] = [b"abc", b"hello", b""];
        for expected in expected_chunks {
            let metadata = decoder.begin_chunk().unwrap();
            assert_eq!(metadata.length, expected.len() as u64);
            let mut data = Vec::new();
            decoder.read_data().read_to_end(&mut data).unwrap();
            assert_eq!(data, expected);
            decoder.end_chunk().unwrap();
        }
        decoder.read_trailer().unwrap();
    }

    #[test]
    fn test_parse_chunk_metadata() {
        let plain = parse_chunk_metadata(b"0a\n").unwrap();
        assert_eq!(plain.length, 10);

        let with_extension = parse_chunk_metadata(b"0a;k1=v1\n").unwrap();
        assert_eq!(with_extension.length, 10);
        let (name, value) = &with_extension.parameters[0];
        assert_eq!(name, "k1");
        assert_eq!(value, "v1");
    }

    #[test]
    fn test_reader() {
        let mut decoded = Vec::new();
        ChunkedReader::new(Cursor::new(BODY))
            .read_to_end(&mut decoded)
            .unwrap();
        assert_eq!(decoded, b"abchello");
    }
}