use std::mem;
use io_stream::{
coroutines::read::{ReadStream, ReadStreamError, ReadStreamResult},
io::StreamIo,
};
use memchr::memmem;
use thiserror::Error;
/// Carriage return byte.
const CR: u8 = b'\r';
/// Line feed byte.
const LF: u8 = b'\n';
/// Line terminator of the chunked framing.
const CRLF: [u8; 2] = [CR, LF];
/// A line terminator immediately followed by an empty line.
const CRLF_CRLF: [u8; 4] = [CRLF[0], CRLF[1], CRLF[0], CRLF[1]];
#[derive(Debug, Error)]
pub enum ReadStreamChunksError {
#[error("Received unexpected EOF")]
UnexpectedEof,
#[error("Received invalid chunk size: {0}")]
InvalidChunkSize(String),
#[error(transparent)]
ReadStream(#[from] ReadStreamError),
}
/// Outcome of a single [`ReadStreamChunks::resume`] step.
#[derive(Debug)]
pub enum ReadStreamChunksResult {
    /// I/O must be performed by the caller; resume with its result.
    Io(StreamIo),
    /// Decoding stopped because of an error.
    Err(ReadStreamChunksError),
    /// Decoding finished; holds the concatenated chunk data.
    Ok(Vec<u8>),
}
/// Position of the decoder within the chunked framing.
#[derive(Debug)]
enum State {
    /// Waiting for a complete `chunk-size [ chunk-ext ] CRLF` line.
    ChunkSize,
    /// Copying chunk bytes; the payload is the number of bytes still to copy,
    /// including the CRLF that follows the chunk data.
    ChunkData(usize),
    /// The zero-size chunk was seen; waiting for the end of the trailer section.
    Trailer,
}
/// Coroutine that decodes an HTTP/1.1 `chunked` transfer-coded body.
///
/// Drive it with [`ReadStreamChunks::resume`] until it yields
/// [`ReadStreamChunksResult::Ok`] or [`ReadStreamChunksResult::Err`].
#[derive(Debug)]
pub struct ReadStreamChunks {
    /// Inner coroutine used to request more raw bytes.
    read: ReadStream,
    /// Current position in the chunked framing.
    state: State,
    /// Raw bytes received but not yet consumed by the decoder.
    buffer: Vec<u8>,
    /// Decoded body accumulated so far.
    body: Vec<u8>,
}
impl ReadStreamChunks {
pub fn new(read: impl Into<ReadStream>) -> Self {
Self {
read: read.into(),
state: State::ChunkSize,
buffer: Vec::new(),
body: Vec::new(),
}
}
pub fn extend(&mut self, bytes: impl IntoIterator<Item = u8>) {
self.buffer.extend(bytes);
}
pub fn resume(&mut self, mut arg: Option<StreamIo>) -> ReadStreamChunksResult {
loop {
match &mut self.state {
State::ChunkSize => {
let Some(crlf) = memmem::find(&self.buffer, &CRLF) else {
let output = match self.read.resume(arg.take()) {
ReadStreamResult::Ok(output) => output,
ReadStreamResult::Err(err) => {
return ReadStreamChunksResult::Err(err.into())
}
ReadStreamResult::Io(io) => return ReadStreamChunksResult::Io(io),
ReadStreamResult::Eof => {
return ReadStreamChunksResult::Err(
ReadStreamChunksError::UnexpectedEof,
)
}
};
self.buffer.extend(output.bytes());
self.read.replace(output.buffer);
continue;
};
let ext = memchr::memchr(b';', &self.buffer[..crlf]).unwrap_or(crlf);
let chunk_size = String::from_utf8_lossy(&self.buffer[..ext]);
let Ok(chunk_size) = usize::from_str_radix(&chunk_size, 16) else {
let chunk_size = chunk_size.to_string();
return ReadStreamChunksResult::Err(
ReadStreamChunksError::InvalidChunkSize(chunk_size),
);
};
if chunk_size == 0 {
self.buffer.drain(..crlf);
self.state = State::Trailer;
continue;
}
self.buffer.drain(..crlf + CRLF.len());
self.state = State::ChunkData(chunk_size + CRLF.len());
}
State::ChunkData(0) => {
self.body.drain(self.body.len() - CRLF.len()..);
self.state = State::ChunkSize;
}
State::ChunkData(_) if self.buffer.is_empty() => {
let output = match self.read.resume(arg.take()) {
ReadStreamResult::Ok(output) => output,
ReadStreamResult::Err(err) => {
return ReadStreamChunksResult::Err(err.into())
}
ReadStreamResult::Io(io) => return ReadStreamChunksResult::Io(io),
ReadStreamResult::Eof => {
return ReadStreamChunksResult::Err(
ReadStreamChunksError::UnexpectedEof,
)
}
};
self.buffer.extend(output.bytes());
self.read.replace(output.buffer);
}
State::ChunkData(size) => {
let min_size = self.buffer.len().min(*size);
self.body.extend(self.buffer.drain(..min_size));
*size -= min_size;
}
State::Trailer => {
let Some(0) = memmem::rfind(&self.buffer, &CRLF_CRLF) else {
let output = match self.read.resume(arg.take()) {
ReadStreamResult::Ok(output) => output,
ReadStreamResult::Err(err) => {
return ReadStreamChunksResult::Err(err.into())
}
ReadStreamResult::Io(io) => return ReadStreamChunksResult::Io(io),
ReadStreamResult::Eof => {
return ReadStreamChunksResult::Err(
ReadStreamChunksError::UnexpectedEof,
)
}
};
self.buffer.extend(output.bytes());
self.read.replace(output.buffer);
continue;
};
break ReadStreamChunksResult::Ok(mem::take(&mut self.body));
}
}
}
}
}
#[cfg(test)]
mod tests {
    use std::io::{BufReader, Read as _};

    use io_stream::{
        coroutines::read::ReadStream,
        io::{StreamIo, StreamOutput},
    };

    use crate::v1_1::coroutines::read_chunks::ReadStreamChunksResult;

    use super::ReadStreamChunks;

    /// Feeds `encoded` through the decoder, serving every read request from an
    /// in-memory reader, and checks that the decoded body equals `decoded`.
    fn test(encoded: &str, decoded: &str) {
        let mut source = BufReader::new(encoded.as_bytes());
        let mut coroutine = ReadStreamChunks::new(ReadStream::default());
        let mut input: Option<StreamIo> = None;

        let body = loop {
            let mut buffer = match coroutine.resume(input.take()) {
                ReadStreamChunksResult::Ok(body) => break body,
                ReadStreamChunksResult::Io(StreamIo::Read(Err(buffer))) => buffer,
                other => unreachable!("Unexpected result: {other:?}"),
            };
            let bytes_count = source.read(&mut buffer).unwrap();
            input = Some(StreamIo::Read(Ok(StreamOutput {
                buffer,
                bytes_count,
            })));
        };

        assert_eq!(body, decoded.as_bytes());
    }

    #[test]
    fn wiki_ru() {
        let encoded = concat!(
            "9\r\n",
            "chunk 1, \r\n",
            "7\r\n",
            "chunk 2\r\n",
            "0\r\n",
            "\r\n",
        );
        test(encoded, "chunk 1, chunk 2");
    }

    #[test]
    fn wiki_fr() {
        let encoded = concat!(
            "27\r\n",
            "Voici les données du premier morceau\r\n\r\n",
            "1C\r\n",
            "et voici un second morceau\r\n\r\n",
            "20\r\n",
            "et voici deux derniers morceaux \r\n",
            "12\r\n",
            "sans saut de ligne\r\n",
            "0\r\n",
            "\r\n",
        );
        let decoded = concat!(
            "Voici les données du premier morceau\r\n",
            "et voici un second morceau\r\n",
            "et voici deux derniers morceaux ",
            "sans saut de ligne",
        );
        test(encoded, decoded);
    }

    #[test]
    fn github_frewsxcv() {
        let encoded = "3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n\r\n";
        test(encoded, "hello world!!!");
    }
}