use std::{
io,
pin::Pin,
task::{Context, Poll},
};
use bytes::{Bytes, BytesMut};
use futures::{Stream, StreamExt, ready};
use crate::{
error::Error,
line::{LineDecoder, LineDecoderOptions},
utils::num::{self, HexEncoder},
};
/// A message body.
///
/// A body is empty, a single in-memory buffer, or a stream of byte chunks.
/// It implements `Stream` and yields its content chunk by chunk.
pub struct Body {
    /// Length in bytes when it is known up front (`None` for streamed bodies).
    size: Option<usize>,
    /// The underlying representation of the content.
    inner: BodyVariant,
}
impl Body {
    /// Creates a body whose content is produced by `stream`.
    ///
    /// The length of such a body is not known in advance, so [`Body::size`]
    /// returns `None`.
    pub fn from_stream<S>(stream: S) -> Self
    where
        S: Stream<Item = io::Result<Bytes>> + Send + 'static,
    {
        Self {
            inner: BodyVariant::Stream(Box::pin(stream)),
            size: None,
        }
    }

    /// Creates an empty body (its size is `Some(0)`).
    #[inline]
    pub const fn empty() -> Self {
        Self {
            inner: BodyVariant::Empty,
            size: Some(0),
        }
    }

    /// Returns the body length in bytes, if it is known up front.
    ///
    /// Bodies built from in-memory data report their length; bodies created
    /// with [`Body::from_stream`] report `None`.
    #[inline]
    pub fn size(&self) -> Option<usize> {
        self.size
    }

    /// Collects the whole body into a single buffer.
    ///
    /// # Errors
    ///
    /// Fails if the underlying stream yields an error, or if `max_size` is
    /// `Some(limit)` and the body is longer than `limit` bytes.
    pub async fn read(self, max_size: Option<usize>) -> io::Result<Bytes> {
        let mut stream = match self.inner {
            BodyVariant::Empty => return Ok(Bytes::new()),
            BodyVariant::Bytes(data) => {
                if Self::exceeds_limit(data.len(), max_size) {
                    return Err(Self::size_exceeded());
                }
                return Ok(data);
            }
            BodyVariant::Stream(stream) => stream,
        };
        let mut body = BytesMut::new();
        while let Some(chunk) = stream.next().await.transpose()? {
            // Check before copying so an oversized chunk is rejected without
            // first being appended to the buffer.
            if Self::exceeds_limit(body.len().saturating_add(chunk.len()), max_size) {
                return Err(Self::size_exceeded());
            }
            body.extend_from_slice(&chunk);
        }
        Ok(body.freeze())
    }

    /// Reads the whole body and drops its content.
    ///
    /// # Errors
    ///
    /// Fails if the underlying stream yields an error, or if `max_size` is
    /// `Some(limit)` and more than `limit` bytes are read.
    pub async fn discard(mut self, max_size: Option<usize>) -> io::Result<()> {
        let mut discarded: usize = 0;
        while let Some(chunk) = self.next().await.transpose()? {
            // Saturate instead of overflowing on very long bodies (reachable on
            // 32-bit targets when no limit is set).
            discarded = discarded.saturating_add(chunk.len());
            if Self::exceeds_limit(discarded, max_size) {
                return Err(Self::size_exceeded());
            }
        }
        Ok(())
    }

    /// Returns `true` if `len` is above the optional `max_size` limit.
    #[inline]
    fn exceeds_limit(len: usize, max_size: Option<usize>) -> bool {
        max_size.is_some_and(|max| len > max)
    }

    /// Builds the error reported when a size limit is exceeded.
    #[cold]
    fn size_exceeded() -> io::Error {
        io::Error::other("maximum body size exceeded")
    }
}
impl Default for Body {
#[inline]
fn default() -> Self {
Self::empty()
}
}
impl From<&'static [u8]> for Body {
    /// Wraps a static byte slice as an in-memory body.
    #[inline]
    fn from(s: &'static [u8]) -> Self {
        let data = Bytes::from_static(s);
        data.into()
    }
}
impl From<&'static str> for Body {
    /// Wraps the UTF-8 bytes of a static string as an in-memory body.
    #[inline]
    fn from(s: &'static str) -> Self {
        let data = Bytes::from_static(s.as_bytes());
        data.into()
    }
}
impl From<Bytes> for Body {
    /// Wraps an in-memory buffer; the body size is the buffer length.
    #[inline]
    fn from(data: Bytes) -> Self {
        // `size` is listed first so the length is read before `data` is moved.
        Self {
            size: Some(data.len()),
            inner: BodyVariant::Bytes(data),
        }
    }
}
impl From<BytesMut> for Body {
#[inline]
fn from(bytes: BytesMut) -> Self {
Self::from(Bytes::from(bytes))
}
}
impl From<Box<[u8]>> for Body {
#[inline]
fn from(bytes: Box<[u8]>) -> Self {
Self::from(Bytes::from(bytes))
}
}
impl From<Vec<u8>> for Body {
#[inline]
fn from(bytes: Vec<u8>) -> Self {
Self::from(Bytes::from(bytes))
}
}
impl From<String> for Body {
#[inline]
fn from(s: String) -> Self {
Self::from(Bytes::from(s))
}
}
impl Stream for Body {
    type Item = io::Result<Bytes>;

    /// Yields the next chunk of the body by delegating to its representation.
    #[inline]
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Body` is `Unpin`, so it can be unwrapped from the pin and re-pinned field-wise.
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_next(cx)
    }
}
/// A boxed, pinned stream of body chunks.
type BoxedByteStream = Pin<Box<dyn Stream<Item = io::Result<Bytes>> + Send>>;

/// Internal representation of a [`Body`].
enum BodyVariant {
    /// No content.
    Empty,
    /// A single in-memory buffer, yielded as one chunk.
    Bytes(Bytes),
    /// An arbitrary stream of chunks.
    Stream(BoxedByteStream),
}
impl Stream for BodyVariant {
    type Item = io::Result<Bytes>;

    /// Yields the next chunk.
    ///
    /// `Empty` is exhausted. `Bytes` yields its buffer once and then turns
    /// into `Empty`. `Stream` delegates to the wrapped stream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let item = match &mut *self {
            Self::Stream(stream) => return stream.poll_next_unpin(cx),
            Self::Empty => None,
            Self::Bytes(data) => Some(Ok(std::mem::take(data))),
        };
        if item.is_some() {
            // The buffer has been handed out, so nothing is left to yield.
            *self = Self::Empty;
        }
        Poll::Ready(item)
    }
}
/// An incremental decoder that extracts message body bytes from an input buffer.
pub trait MessageBodyDecoder {
    /// Decodes as much of the body as currently possible from `data`.
    ///
    /// Input that is used is removed from `data`. Returns `Ok(Some(chunk))`
    /// with the next piece of the body, or `Ok(None)` when no body bytes can
    /// be produced from the current input.
    fn decode(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error>;

    /// Like [`decode`](Self::decode), but called when no more input will arrive.
    ///
    /// Decoders whose framing determines the body length report a truncated
    /// body as an error.
    fn decode_eof(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error>;

    /// Returns `true` once the whole body has been decoded.
    fn is_complete(&self) -> bool;
}
/// A decoder that treats all input up to the end of the input as body.
pub struct SimpleBodyDecoder {
    /// Set once `decode_eof` has been called.
    complete: bool,
}
impl SimpleBodyDecoder {
#[inline]
pub const fn new() -> Self {
Self { complete: false }
}
}
impl Default for SimpleBodyDecoder {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl MessageBodyDecoder for SimpleBodyDecoder {
    /// Returns `true` once `decode_eof` has been called.
    #[inline]
    fn is_complete(&self) -> bool {
        self.complete
    }

    /// Takes everything currently in `data` as body.
    ///
    /// Returns `Ok(None)` when `data` is empty or the body is already
    /// complete; in the latter case `data` is left untouched.
    fn decode(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error> {
        if self.complete || data.is_empty() {
            return Ok(None);
        }
        Ok(Some(data.split().freeze()))
    }

    /// Takes the remaining input as body and marks the body complete.
    #[inline]
    fn decode_eof(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error> {
        let result = self.decode(data);
        self.complete = true;
        result
    }
}
/// A decoder for a body of a fixed, known length.
pub struct FixedSizeBodyDecoder {
    /// Number of body bytes still to be produced.
    expected: usize,
}
impl FixedSizeBodyDecoder {
#[inline]
pub const fn new(expected: usize) -> Self {
Self { expected }
}
}
impl MessageBodyDecoder for FixedSizeBodyDecoder {
#[inline]
fn is_complete(&self) -> bool {
self.expected == 0
}
fn decode(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error> {
if self.is_complete() {
return Ok(None);
}
let take = self.expected.min(data.len());
self.expected -= take;
let data = data.split_to(take);
if data.is_empty() {
Ok(None)
} else {
Ok(Some(data.freeze()))
}
}
fn decode_eof(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error> {
if let Some(chunk) = self.decode(data)? {
Ok(Some(chunk))
} else if self.is_complete() {
Ok(None)
} else {
Err(Error::from_static_msg("incomplete body"))
}
}
}
/// States of the [`ChunkedBodyDecoder`] state machine.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ChunkedDecoderState {
    /// Waiting for a chunk-size line.
    ChunkHeader,
    /// Reading the data of the current chunk.
    ChunkBody,
    /// Waiting for the line that ends the current chunk's data.
    ChunkBodyDelimiter,
    /// Reading trailer lines until an empty line.
    TrailerPart,
    /// The whole body has been decoded.
    Completed,
}
/// A decoder for a body that uses chunked transfer coding.
pub struct ChunkedBodyDecoder {
    /// Current position in the state machine.
    state: ChunkedDecoderState,
    /// Splits chunk-size, delimiter, and trailer lines off the input.
    line_decoder: LineDecoder,
    /// Bytes of the current chunk still to be read.
    expected: usize,
}
impl ChunkedBodyDecoder {
    /// Creates a decoder for a chunked body.
    ///
    /// `max_line_length` is forwarded to the internal line decoder, which
    /// handles chunk-size, chunk-delimiter, and trailer lines. A too-long
    /// chunk-size line makes `decode` fail. `None` presumably means "no
    /// limit" (TODO confirm against `LineDecoderOptions`).
    #[inline]
    pub fn new(max_line_length: Option<usize>) -> Self {
        // Only the CRLF terminator is enabled (cr/lf are disabled).
        let options = LineDecoderOptions::new()
            .cr(false)
            .lf(false)
            .crlf(true)
            .require_terminator(false)
            .max_line_length(max_line_length);
        Self {
            state: ChunkedDecoderState::ChunkHeader,
            line_decoder: LineDecoder::new(options),
            expected: 0,
        }
    }

    /// Runs a single state-machine step on `data`.
    fn decoding_step(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error> {
        match self.state {
            ChunkedDecoderState::ChunkHeader => self.decode_chunk_header(data),
            ChunkedDecoderState::ChunkBody => self.decode_chunk_body(data),
            ChunkedDecoderState::ChunkBodyDelimiter => self.decode_chunk_body_delimiter(data),
            ChunkedDecoderState::TrailerPart => self.decode_trailer_part(data),
            ChunkedDecoderState::Completed => Ok(None),
        }
    }

    /// Parses a chunk-size line of the form `<hex size>[;extensions]`.
    ///
    /// Chunk extensions are ignored. A non-zero size starts reading chunk data;
    /// a zero size starts the trailer section.
    fn decode_chunk_header(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error> {
        let Some(header) = self.line_decoder.decode(data)? else {
            return Ok(None);
        };
        let size_field = match header.iter().position(|&b| b == b';') {
            // RFC 9112 allows optional whitespace (BWS) before the `;` that
            // starts a chunk extension, e.g. `a ;name=value`.
            Some(end) => Self::trim_trailing_bws(&header[..end]),
            None => &header[..],
        };
        let size = num::decode_hex(size_field)?;
        self.expected = size;
        self.state = if size > 0 {
            ChunkedDecoderState::ChunkBody
        } else {
            ChunkedDecoderState::TrailerPart
        };
        Ok(None)
    }

    /// Takes up to the remaining bytes of the current chunk from `data`.
    fn decode_chunk_body(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error> {
        let take = self.expected.min(data.len());
        self.expected -= take;
        let data = data.split_to(take);
        if self.expected == 0 {
            self.state = ChunkedDecoderState::ChunkBodyDelimiter;
        }
        if data.is_empty() {
            Ok(None)
        } else {
            Ok(Some(data.freeze()))
        }
    }

    /// Consumes the line that ends a chunk's data.
    ///
    /// The line's content is not checked: anything up to the next line
    /// terminator is skipped.
    fn decode_chunk_body_delimiter(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error> {
        if self.line_decoder.decode(data)?.is_some() {
            self.state = ChunkedDecoderState::ChunkHeader;
        }
        Ok(None)
    }

    /// Skips trailer lines; an empty line completes the body.
    fn decode_trailer_part(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error> {
        if let Some(line) = self.line_decoder.decode(data)? {
            if line.is_empty() {
                self.state = ChunkedDecoderState::Completed;
            }
        }
        Ok(None)
    }

    /// Strips trailing spaces and horizontal tabs from `bytes`.
    fn trim_trailing_bws(mut bytes: &[u8]) -> &[u8] {
        while let [rest @ .., b' ' | b'\t'] = bytes {
            bytes = rest;
        }
        bytes
    }
}
impl MessageBodyDecoder for ChunkedBodyDecoder {
#[inline]
fn is_complete(&self) -> bool {
self.state == ChunkedDecoderState::Completed
}
fn decode(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error> {
while !self.is_complete() && !data.is_empty() {
let res = self.decoding_step(data)?;
if res.is_some() {
return Ok(res);
}
}
Ok(None)
}
fn decode_eof(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, Error> {
if let Some(chunk) = self.decode(data)? {
Ok(Some(chunk))
} else if self.is_complete() {
Ok(None)
} else {
Err(Error::from_static_msg("incomplete body"))
}
}
}
pin_project_lite::pin_project! {
    /// A stream adapter that emits the items of an inner byte stream
    /// encoded with chunked transfer coding.
    pub struct ChunkedStream<S> {
        // The inner stream; replaced by `None` once it has ended or failed.
        #[pin]
        stream: Option<S>,
        // Scratch buffer in which each encoded chunk is assembled.
        chunk_buffer: BytesMut,
        // Formats the chunk-size field of each chunk.
        hex_encoder: HexEncoder,
    }
}
impl<S> ChunkedStream<S> {
    /// Wraps `stream` so that its items are emitted as encoded chunks.
    pub fn new(stream: S) -> Self {
        let chunk_buffer = BytesMut::new();
        let hex_encoder = HexEncoder::new();
        Self {
            stream: Some(stream),
            chunk_buffer,
            hex_encoder,
        }
    }
}
impl<S, E> Stream for ChunkedStream<S>
where
    S: Stream<Item = Result<Bytes, E>>,
{
    type Item = Result<Bytes, E>;

    /// Polls the inner stream and emits the next encoded chunk.
    ///
    /// * A non-empty data item becomes `<size>\r\n<data>\r\n`.
    /// * Empty data items are skipped. A zero-size chunk is the chunked-coding
    ///   terminator, so encoding one would end the body prematurely.
    /// * When the inner stream ends, `0\r\n\r\n` is emitted once, and the
    ///   adapter is finished afterwards.
    /// * An inner error is forwarded once, and the adapter is finished
    ///   afterwards without emitting the terminating chunk.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            let Some(stream) = this.stream.as_mut().as_pin_mut() else {
                return Poll::Ready(None);
            };
            let item = match ready!(stream.poll_next(cx)) {
                // Never encode an empty item; see the doc comment above.
                Some(Ok(data)) if data.is_empty() => continue,
                Some(Ok(data)) => {
                    let encoded_size = this.hex_encoder.encode(data.len());
                    // size field + CRLF + data + CRLF
                    let chunk_size = encoded_size.len() + data.len() + 4;
                    this.chunk_buffer.reserve(chunk_size);
                    this.chunk_buffer.extend_from_slice(encoded_size);
                    this.chunk_buffer.extend_from_slice(b"\r\n");
                    this.chunk_buffer.extend_from_slice(&data);
                    this.chunk_buffer.extend_from_slice(b"\r\n");
                    Ok(this.chunk_buffer.split().freeze())
                }
                Some(Err(err)) => {
                    this.stream.set(None);
                    Err(err)
                }
                None => {
                    this.stream.set(None);
                    Ok(Bytes::from("0\r\n\r\n"))
                }
            };
            return Poll::Ready(Some(item));
        }
    }
}
#[cfg(test)]
mod tests {
    use bytes::BytesMut;

    use super::{ChunkedBodyDecoder, FixedSizeBodyDecoder, MessageBodyDecoder, SimpleBodyDecoder};

    /// The simple decoder takes all input as body until `decode_eof` completes it.
    #[test]
    fn test_simple_body_decoder() {
        let mut decoder = SimpleBodyDecoder::new();
        let mut data = BytesMut::from("foo");
        assert!(!decoder.is_complete());
        let res = decoder.decode(&mut data).unwrap().unwrap();
        assert!(data.is_empty());
        assert_eq!(res, "foo");
        assert!(!decoder.is_complete());
        let mut data = BytesMut::from("bar");
        let res = decoder.decode_eof(&mut data).unwrap().unwrap();
        assert!(data.is_empty());
        assert_eq!(res, "bar");
        assert!(decoder.is_complete());
        let mut data = BytesMut::from("abcd");
        let res = decoder.decode(&mut data).unwrap();
        assert_eq!(data, "abcd");
        assert_eq!(res, None);
    }

    /// The fixed-size decoder stops at the expected length and leaves the rest.
    #[test]
    fn test_fixed_size_body_decoder() {
        let decoder = FixedSizeBodyDecoder::new(0);
        assert!(decoder.is_complete());
        let mut decoder = FixedSizeBodyDecoder::new(10);
        assert!(!decoder.is_complete());
        let mut data = BytesMut::from("1234");
        let res = decoder.decode(&mut data).unwrap().unwrap();
        assert!(data.is_empty());
        assert_eq!(res, "1234");
        assert!(!decoder.is_complete());
        let mut data = BytesMut::from("123456789");
        let res = decoder.decode(&mut data).unwrap().unwrap();
        assert_eq!(data, "789");
        assert_eq!(res, "123456");
        assert!(decoder.is_complete());
        let res = decoder.decode(&mut data).unwrap();
        assert_eq!(data, "789");
        assert_eq!(res, None);
    }

    /// Input ending before the expected length is reported as an error.
    #[test]
    fn test_fixed_size_body_decoder_on_incomplete_body() {
        let mut decoder = FixedSizeBodyDecoder::new(5);
        let mut data = BytesMut::from("12");
        let res = decoder.decode_eof(&mut data).unwrap().unwrap();
        assert_eq!(res, "12");
        assert!(!decoder.is_complete());
        let res = decoder.decode_eof(&mut data);
        assert!(res.is_err());
    }

    /// Extensions, delimiter garbage, and trailers are skipped; trailing input is kept.
    #[test]
    fn test_chunked_body_decoder() {
        let data = "a;foo=bar\r\n".to_string()
            + "0123456789 and some garbage\r\n"
            + "0\r\n"
            + "and\r\n"
            + "some\r\n"
            + "garbage again\r\n"
            + "\r\n"
            + "and this is a new message";
        let mut data = BytesMut::from(data.as_str());
        let mut decoder = ChunkedBodyDecoder::new(Some(256));
        assert!(!decoder.is_complete());
        let res = decoder.decode(&mut data).unwrap().unwrap();
        assert!(!decoder.is_complete());
        assert_eq!(res, "0123456789");
        assert_eq!(
            data,
            " and some garbage\r\n".to_string()
                + "0\r\n"
                + "and\r\n"
                + "some\r\n"
                + "garbage again\r\n"
                + "\r\n"
                + "and this is a new message"
        );
        let res = decoder.decode(&mut data).unwrap();
        assert!(decoder.is_complete());
        assert!(res.is_none());
        assert_eq!(data, "and this is a new message");
        let res = decoder.decode(&mut data).unwrap();
        assert!(decoder.is_complete());
        assert!(res.is_none());
        assert_eq!(data, "and this is a new message");
    }

    /// Chunk data split across several input buffers is decoded piece by piece.
    #[test]
    fn test_chunked_body_decoder_on_split_input() {
        let mut decoder = ChunkedBodyDecoder::new(Some(256));
        let mut data = BytesMut::from("3\r\nab");
        assert_eq!(decoder.decode(&mut data).unwrap().unwrap(), "ab");
        assert!(data.is_empty());
        data.extend_from_slice(b"c\r\n0\r\n\r\n");
        assert_eq!(decoder.decode(&mut data).unwrap().unwrap(), "c");
        assert!(decoder.decode(&mut data).unwrap().is_none());
        assert!(decoder.is_complete());
        assert!(data.is_empty());
    }

    /// Input ending in the middle of a chunk is reported as an error.
    #[test]
    fn test_chunked_body_decoder_on_incomplete_body() {
        let mut data = BytesMut::from("5\r\n01");
        let mut decoder = ChunkedBodyDecoder::new(Some(256));
        let res = decoder.decode_eof(&mut data).unwrap().unwrap();
        assert_eq!(res, "01");
        assert!(data.is_empty());
        let res = decoder.decode_eof(&mut data);
        assert!(res.is_err());
    }

    /// A chunk size that is not a hex number is rejected.
    #[test]
    fn test_chunked_body_decoder_on_invalid_chunk_size() {
        let mut data = BytesMut::from("ggg\r\n0123456789\r\n0\r\n\r\n");
        let mut decoder = ChunkedBodyDecoder::new(Some(256));
        let res = decoder.decode(&mut data);
        assert!(res.is_err());
    }

    /// A chunk-size line longer than the configured limit is rejected.
    #[test]
    fn test_chunked_body_decoder_on_line_length_exceeded() {
        let mut data = BytesMut::from("5;very_long_attribute=val\r\n01234\r\n0\r\n\r\n");
        let mut decoder = ChunkedBodyDecoder::new(Some(5));
        let res = decoder.decode(&mut data);
        assert!(res.is_err());
    }
}