use core::{fmt, mem};
use std::io;
use tracing::{trace, warn};
use crate::{
bytes::{Buf, BufMut, Bytes, BytesMut},
http::header::{HeaderMap, HeaderName, HeaderValue},
};
use super::{buf_write::H1BufWrite, error::ProtoError};
const TRAILER_MAX_HEADER_SIZE: usize = 1024 * 16;
/// State of an HTTP/1 message body coder, used by both the `decode` and `encode` paths.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TransferCoding {
/// Finished: `decode` reports `ChunkResult::AlreadyEof` and `encode` only logs a warning.
Eof,
/// Unusable coder: `decode` reports `ChunkResult::Corrupted`.
Corrupted,
/// Fixed-length body; the value is the number of bytes still to be decoded or encoded.
Length(u64),
/// Chunked body being decoded.
DecodeChunked {
/// Current position inside the chunked framing state machine.
state: ChunkedState,
/// Accumulator for the chunk-size line; counts down the remaining bytes while a chunk
/// body is being read.
size: u64,
/// Raw trailer section collected after the zero-size chunk.
trailers: Trailers,
},
/// Chunked body being encoded: data is written through `write_buf_bytes_chunked`.
EncodeChunked,
/// Pass-through mode: bytes are forwarded as they are, without any framing.
Upgrade,
}
/// Accumulator for the raw trailer section of a chunked body.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Trailers {
/// Raw trailer bytes; stays `None` until the first trailer byte is stored.
buf: Option<Box<BytesMut>>,
/// Number of trailer lines counted so far.
len: usize,
/// Maximum value `len` may reach before `incr_len` fails.
limit: usize,
/// Maximum number of bytes `buf` may hold before `put` fails.
size_limit: usize,
}
impl Trailers {
    /// Creates an empty collector that accepts at most `header_limit` trailer lines.
    ///
    /// The byte budget of the raw trailer section is fixed to `TRAILER_MAX_HEADER_SIZE`.
    pub fn new(header_limit: usize) -> Self {
        Self {
            buf: None,
            len: 0,
            limit: header_limit,
            size_limit: TRAILER_MAX_HEADER_SIZE,
        }
    }

    /// Stores `byte` only when a trailer buffer has already been started.
    ///
    /// Does nothing (and succeeds) when no trailer byte has been collected yet.
    fn try_put(&mut self, byte: u8) -> io::Result<()> {
        match self.buf {
            Some(_) => self.put(byte),
            None => Ok(()),
        }
    }

    /// Appends `byte` to the trailer buffer, allocating the buffer on first use.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` once an already existing buffer grows past `size_limit`.
    fn put(&mut self, byte: u8) -> io::Result<()> {
        match self.buf.as_deref_mut() {
            None => {
                // First trailer byte: start a fresh buffer. No limit check on this path.
                let mut fresh = BytesMut::with_capacity(64);
                fresh.put_u8(byte);
                self.buf = Some(Box::new(fresh));
            }
            Some(raw) => {
                raw.put_u8(byte);
                if raw.len() > self.size_limit {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "chunk trailers bytes over limit",
                    ));
                }
            }
        }
        Ok(())
    }

    /// Counts one more trailer line.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` when the count exceeds `limit`.
    fn incr_len(&mut self) -> io::Result<()> {
        self.len += 1;
        if self.len <= self.limit {
            return Ok(());
        }
        Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "chunk trailers count overflow",
        ))
    }

    /// Moves the collected state out, leaving an empty collector with a zero limit behind.
    ///
    /// Returns `None` when no trailer byte was ever stored.
    fn take(&mut self) -> Option<Self> {
        if self.buf.is_none() {
            return None;
        }
        Some(mem::replace(self, Trailers::new(0)))
    }

    /// Parses the collected raw bytes into a `HeaderMap`.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidInput` when the bytes are rejected or reported as partial by
    /// `httparse`, or when a name or value can not be converted.
    ///
    /// # Panics
    ///
    /// Panics when no trailer byte was ever stored.
    fn decode(self) -> io::Result<HeaderMap> {
        let invalid = |msg: &'static str| io::Error::new(io::ErrorKind::InvalidInput, msg);

        let Self { buf, len, .. } = self;
        let raw = buf.expect("trailer buf must be initialized");

        // One parse slot per counted trailer line.
        let mut slots = vec![httparse::EMPTY_HEADER; len];
        let parsed = match httparse::parse_headers(&raw, &mut slots) {
            Ok(httparse::Status::Complete((_, parsed))) => parsed,
            Ok(httparse::Status::Partial) => return Err(invalid("partial trailer headers")),
            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
        };

        let mut map = HeaderMap::with_capacity(parsed.len());
        for header in parsed {
            let name = HeaderName::from_bytes(header.name.as_bytes())
                .map_err(|_| invalid("invalid trailer header name"))?;
            let value = HeaderValue::from_bytes(header.value)
                .map_err(|_| invalid("invalid trailer header value"))?;
            map.append(name, value);
        }
        Ok(map)
    }
}
impl TransferCoding {
#[inline]
pub const fn eof() -> Self {
Self::Eof
}
#[inline]
pub const fn length(len: u64) -> Self {
Self::Length(len)
}
#[inline]
pub fn decode_chunked(header_limit: usize) -> Self {
Self::DecodeChunked {
state: ChunkedState::Size,
size: 0,
trailers: Trailers::new(header_limit),
}
}
#[inline]
pub const fn encode_chunked() -> Self {
Self::EncodeChunked
}
#[inline]
pub const fn upgrade() -> Self {
Self::Upgrade
}
#[inline]
pub fn is_eof(&self) -> bool {
match self {
Self::Eof => true,
Self::EncodeChunked => unreachable!("TransferCoding can't decide eof state when encoding chunked data"),
_ => false,
}
}
#[inline]
pub fn is_upgrade(&self) -> bool {
matches!(self, Self::Upgrade)
}
}
/// States of the chunked-body decoding state machine driven by `ChunkedState::step`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ChunkedState {
/// Reading the hexadecimal digits of a chunk size.
Size,
/// Skipping tabs/spaces that follow the chunk size.
SizeLws,
/// Skipping a chunk extension until its terminating CR.
Extension,
/// Expecting the LF that ends the chunk-size line; leads to `Body` for a non-zero size
/// and to `EndCr` for a zero size.
SizeLf,
/// Reading chunk payload bytes.
Body,
/// Expecting the CR that follows a chunk payload.
BodyCr,
/// Expecting the LF that follows a chunk payload; leads back to `Size`.
BodyLf,
/// Reading the bytes of a trailer line until its CR.
Trailer,
/// Expecting the LF that ends a trailer line.
TrailerLf,
/// At the start of a line after the zero-size chunk or after a trailer line: a CR leads
/// to `EndLf`, any other byte starts a trailer line.
EndCr,
/// Expecting the final LF of the body.
EndLf,
/// The whole chunked body has been consumed.
End,
}
// Pops the first byte off `$rdr` and evaluates to it.
//
// When `$rdr` holds no bytes the *enclosing function* returns `Ok(None)`, which is how the
// state machine signals that more input is required.
macro_rules! byte (
    ($rdr:ident) => ({
        if $rdr.is_empty() {
            return Ok(None);
        }
        let first = $rdr[0];
        $rdr.advance(1);
        first
    })
);
impl ChunkedState {
    /// Advances the state machine by one step, dispatching on the current state.
    ///
    /// * `body` - unread input; consumed bytes are removed from its front.
    /// * `size` - chunk-size accumulator / remaining payload counter.
    /// * `buf` - receives a payload slice when the step produced body data.
    /// * `trailers` - collector for raw trailer bytes.
    ///
    /// Returns `Ok(Some(next))` with the state to move to, or `Ok(None)` when `body` has no
    /// byte available for the current state.
    ///
    /// # Errors
    ///
    /// Fails when the input violates the chunked framing or a trailer limit is exceeded.
    pub fn step(
        &mut self,
        body: &mut BytesMut,
        size: &mut u64,
        buf: &mut Option<Bytes>,
        trailers: &mut Trailers,
    ) -> io::Result<Option<Self>> {
        match self {
            Self::Size => Self::read_size(body, size),
            Self::SizeLws => Self::read_size_lws(body),
            Self::Extension => Self::read_extension(body),
            Self::SizeLf => Self::read_size_lf(body, size),
            Self::Body => Self::read_body(body, size, buf),
            Self::BodyCr => Self::read_body_cr(body),
            Self::BodyLf => Self::read_body_lf(body),
            Self::Trailer => Self::read_trailer(body, trailers),
            Self::TrailerLf => Self::read_trailer_lf(body, trailers),
            Self::EndCr => Self::read_end_cr(body, trailers),
            Self::EndLf => Self::read_end_lf(body, trailers),
            Self::End => Ok(Some(Self::End)),
        }
    }

    /// Builds the `InvalidInput` error used for malformed framing bytes.
    fn invalid_input(msg: &'static str) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidInput, msg)
    }

    /// Consumes one byte of the chunk-size line, folding hex digits into `size`.
    fn read_size(rdr: &mut BytesMut, size: &mut u64) -> io::Result<Option<Self>> {
        const RADIX: u64 = 16;

        // Either extract the numeric value of a hex digit or leave the size state.
        let digit = match byte!(rdr) {
            b @ b'0'..=b'9' => b - b'0',
            b @ b'a'..=b'f' => b - b'a' + 10,
            b @ b'A'..=b'F' => b - b'A' + 10,
            b'\t' | b' ' => return Ok(Some(Self::SizeLws)),
            b';' => return Ok(Some(Self::Extension)),
            b'\r' => return Ok(Some(Self::SizeLf)),
            _ => return Err(Self::invalid_input("Invalid chunk size line: Invalid Size")),
        };

        *size = size
            .checked_mul(RADIX)
            .and_then(|scaled| scaled.checked_add(u64::from(digit)))
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid chunk size: overflow"))?;

        Ok(Some(Self::Size))
    }

    /// Consumes one byte of the whitespace that may follow the chunk size.
    fn read_size_lws(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
        let next = match byte!(rdr) {
            b'\t' | b' ' => Self::SizeLws,
            b';' => Self::Extension,
            b'\r' => Self::SizeLf,
            _ => return Err(Self::invalid_input("Invalid chunk size linear white space")),
        };
        Ok(Some(next))
    }

    /// Consumes one byte of a chunk extension; extension content is discarded.
    fn read_extension(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
        let next = match byte!(rdr) {
            b'\r' => Self::SizeLf,
            b'\n' => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "invalid chunk extension contains newline",
                ));
            }
            _ => Self::Extension,
        };
        Ok(Some(next))
    }

    /// Expects the LF ending the chunk-size line and picks the next phase from `size`.
    fn read_size_lf(rdr: &mut BytesMut, size: &u64) -> io::Result<Option<Self>> {
        match byte!(rdr) {
            // A zero size marks the last chunk; anything else has a payload to read.
            b'\n' => Ok(Some(if *size == 0 { Self::EndCr } else { Self::Body })),
            _ => Err(Self::invalid_input("Invalid chunk size LF")),
        }
    }

    /// Hands out as much chunk payload as is available, bounded by the remaining size `rem`.
    fn read_body(rdr: &mut BytesMut, rem: &mut u64, buf: &mut Option<Bytes>) -> io::Result<Option<Self>> {
        if rdr.is_empty() {
            return Ok(None);
        }
        *buf = Some(bounded_split(rem, rdr));
        let next = if *rem == 0 { Self::BodyCr } else { Self::Body };
        Ok(Some(next))
    }

    /// Expects the CR that follows a chunk payload.
    fn read_body_cr(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
        if byte!(rdr) == b'\r' {
            Ok(Some(Self::BodyLf))
        } else {
            Err(Self::invalid_input("Invalid chunk body CR"))
        }
    }

    /// Expects the LF that follows a chunk payload, then goes back to reading a size.
    fn read_body_lf(rdr: &mut BytesMut) -> io::Result<Option<Self>> {
        if byte!(rdr) == b'\n' {
            Ok(Some(Self::Size))
        } else {
            Err(Self::invalid_input("Invalid chunk body LF"))
        }
    }

    /// Stores one byte of a trailer line; a CR moves on to expecting its LF.
    fn read_trailer(rdr: &mut BytesMut, trailers: &mut Trailers) -> io::Result<Option<Self>> {
        trace!(target: "h1_decode", "read_trailer");
        let byte = byte!(rdr);
        trailers.put(byte)?;
        let next = if byte == b'\r' { Self::TrailerLf } else { Self::Trailer };
        Ok(Some(next))
    }

    /// Expects the LF ending a trailer line, counting the line and storing the byte.
    fn read_trailer_lf(rdr: &mut BytesMut, trailers: &mut Trailers) -> io::Result<Option<Self>> {
        let byte = byte!(rdr);
        if byte != b'\n' {
            return Err(Self::invalid_input("Invalid trailer end LF"));
        }
        trailers.incr_len()?;
        trailers.put(byte)?;
        Ok(Some(Self::EndCr))
    }

    /// Decides between the terminating blank line (CR) and the start of a trailer line.
    fn read_end_cr(rdr: &mut BytesMut, trailers: &mut Trailers) -> io::Result<Option<Self>> {
        let byte = byte!(rdr);
        if byte == b'\r' {
            // Only kept when trailers were collected, so the stored section stays complete.
            trailers.try_put(byte)?;
            Ok(Some(Self::EndLf))
        } else {
            trailers.put(byte)?;
            Ok(Some(Self::Trailer))
        }
    }

    /// Expects the final LF of the chunked body.
    fn read_end_lf(rdr: &mut BytesMut, trailers: &mut Trailers) -> io::Result<Option<Self>> {
        let byte = byte!(rdr);
        if byte != b'\n' {
            return Err(Self::invalid_input("Invalid chunk end LF"));
        }
        trailers.try_put(byte)?;
        Ok(Some(Self::End))
    }
}
impl TransferCoding {
    /// Tries to switch this coder to `other`.
    ///
    /// * Setting `Upgrade` on an `Upgrade` coder, or setting `Length(0)` on any coder, is
    ///   accepted and leaves the coder unchanged.
    /// * Any other change of an `Upgrade`, `DecodeChunked` or `Length` coder is rejected.
    /// * Every other coder is replaced by `other`.
    ///
    /// # Errors
    ///
    /// Returns `ProtoError::HeaderName` for a rejected change.
    pub fn try_set(&mut self, other: Self) -> Result<(), ProtoError> {
        let keep_current = matches!(other, Self::Length(0))
            || matches!((&*self, &other), (Self::Upgrade, Self::Upgrade));
        if keep_current {
            return Ok(());
        }

        match *self {
            Self::Upgrade | Self::DecodeChunked { .. } | Self::Length(..) => Err(ProtoError::HeaderName),
            _ => {
                *self = other;
                Ok(())
            }
        }
    }

    /// Puts the coder into the `Eof` state.
    #[inline]
    pub fn set_eof(&mut self) {
        *self = TransferCoding::Eof;
    }

    /// Puts the coder into the `Corrupted` state.
    #[inline]
    pub fn set_corrupted(&mut self) {
        *self = TransferCoding::Corrupted;
    }

    /// Writes `bytes` into `buf` according to the current coding.
    ///
    /// Empty input is ignored. A `Length` coder writes at most its remaining byte count and
    /// drops whatever exceeds it. An `Eof` coder writes nothing and logs a warning.
    ///
    /// # Panics
    ///
    /// Panics when called on a coder in any other state.
    pub fn encode<W>(&mut self, mut bytes: Bytes, buf: &mut W)
    where
        W: H1BufWrite,
    {
        if bytes.is_empty() {
            return;
        }

        match self {
            Self::Upgrade => buf.write_buf_bytes(bytes),
            Self::EncodeChunked => buf.write_buf_bytes_chunked(bytes),
            Self::Length(remaining) => {
                let incoming = bytes.len() as u64;
                if incoming <= *remaining {
                    *remaining -= incoming;
                    buf.write_buf_bytes(bytes);
                } else {
                    // More data than announced: emit only what is left and zero the counter.
                    let allowed = mem::take(remaining);
                    buf.write_buf_bytes(bytes.split_to(allowed as usize));
                }
            }
            Self::Eof => warn!(target: "h1_encode", "TransferCoding::Eof should not encode response body"),
            _ => unreachable!(),
        }
    }

    /// Finishes the encoded body.
    ///
    /// A chunked encoder writes `trailers` when given, otherwise the plain terminating chunk.
    /// `Eof`, `Upgrade` and an exhausted `Length` coder write nothing.
    ///
    /// # Panics
    ///
    /// Panics when a `Length` coder still has bytes remaining, or on any other state.
    pub fn encode_eof<W>(&mut self, trailers: Option<HeaderMap>, buf: &mut W)
    where
        W: H1BufWrite,
    {
        match *self {
            Self::Eof | Self::Upgrade | Self::Length(0) => {}
            Self::EncodeChunked => {
                if let Some(trailers) = trailers {
                    buf.write_buf_trailers(trailers);
                } else {
                    buf.write_buf_static(b"0\r\n\r\n");
                }
            }
            Self::Length(n) => unreachable!("UnexpectedEof for Length Body with {} remaining", n),
            _ => unreachable!(),
        }
    }

    /// Decodes the next piece of body from `src`.
    ///
    /// Consumed bytes are removed from the front of `src`. The result tells whether data or
    /// trailers were produced, more input is needed, the body ended, or decoding failed.
    /// An exhausted `Length` coder and a chunked decoder in its `End` state switch to `Eof`
    /// and report `ChunkResult::OnEof`.
    ///
    /// # Panics
    ///
    /// Panics when called with non-empty `src` on a coder that is not a decoder state.
    pub fn decode(&mut self, src: &mut BytesMut) -> ChunkResult {
        match *self {
            Self::Length(0)
            | Self::DecodeChunked {
                state: ChunkedState::End,
                ..
            } => {
                self.set_eof();
                ChunkResult::OnEof
            }
            Self::Eof => ChunkResult::AlreadyEof,
            Self::Corrupted => ChunkResult::Corrupted,
            // Every remaining state needs at least one input byte to make progress.
            _ if src.is_empty() => ChunkResult::InsufficientData,
            Self::Length(ref mut remaining) => ChunkResult::Ok(bounded_split(remaining, src)),
            Self::Upgrade => ChunkResult::Ok(src.split().freeze()),
            Self::DecodeChunked {
                ref mut state,
                ref mut size,
                ref mut trailers,
            } => loop {
                let mut chunk = None;
                match state.step(src, size, &mut chunk, trailers) {
                    Ok(Some(next)) => *state = next,
                    Ok(None) => return ChunkResult::InsufficientData,
                    Err(e) => return ChunkResult::Err(e),
                }

                if matches!(*state, ChunkedState::End) {
                    // Hand out collected trailers first; the eof transition then happens on
                    // the next call. Without trailers, re-enter to perform it right away.
                    if let Some(collected) = trailers.take() {
                        return collected
                            .decode()
                            .map_or_else(ChunkResult::Err, ChunkResult::Trailers);
                    }
                    return self.decode(src);
                }

                if let Some(chunk) = chunk {
                    return ChunkResult::Ok(chunk);
                }
            },
            _ => unreachable!(),
        }
    }
}
/// Outcome of a single `TransferCoding::decode` call.
#[derive(Debug)]
pub enum ChunkResult {
/// A piece of decoded body data.
Ok(Bytes),
/// Trailer headers parsed at the end of a chunked body.
Trailers(HeaderMap),
/// Decoding failed with the given I/O error.
Err(io::Error),
/// The input did not hold enough bytes to make progress.
InsufficientData,
/// The coder switched to the `Eof` state during this call.
OnEof,
/// The coder was already in the `Eof` state.
AlreadyEof,
/// The coder is in the `Corrupted` state.
Corrupted,
}
impl fmt::Display for ChunkResult {
    /// Writes a short human readable description of the result.
    ///
    /// An `Err` result is rendered through the wrapped error's own `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::Err(e) => return fmt::Display::fmt(e, f),
            Self::Ok(_) => "chunked data.",
            Self::Trailers(_) => "trailer headers.",
            Self::InsufficientData => "no sufficient data. More input bytes required.",
            Self::OnEof => "coder reached EOF state. no more chunk can be produced.",
            Self::AlreadyEof => "coder already reached EOF state. no more chunk can be produced.",
            Self::Corrupted => "coder corrupted. can not be used anymore.",
        };
        f.write_str(text)
    }
}
impl From<io::Error> for ChunkResult {
fn from(e: io::Error) -> Self {
Self::Err(e)
}
}
/// Splits at most `*rem` bytes off the front of `buf` and lowers `*rem` by the amount taken.
///
/// When `buf` holds no more than `*rem` bytes the whole buffer is taken; otherwise exactly
/// `*rem` bytes are taken and `*rem` becomes zero.
fn bounded_split(rem: &mut u64, buf: &mut BytesMut) -> Bytes {
    let available = buf.len() as u64;
    let taken = match rem.checked_sub(available) {
        // Everything buffered still belongs to the body.
        Some(left) => {
            *rem = left;
            buf.split()
        }
        // The buffer holds more than is owed; the count is below `buf.len()` here.
        None => {
            let owed = mem::take(rem);
            buf.split_to(owed as usize)
        }
    };
    taken.freeze()
}
#[cfg(test)]
mod test {
    use super::*;

    /// Chunk-size line parsing: accepted forms, rejected forms and overflow.
    #[test]
    fn test_read_chunk_size() {
        use std::io::ErrorKind::{InvalidData, InvalidInput, UnexpectedEof};

        // Runs the state machine over `s` until the size line is complete and returns the size.
        fn read(s: &str) -> u64 {
            let mut rdr = BytesMut::from(s);
            let mut size = 0;
            let mut state = ChunkedState::Size;
            while !matches!(state, ChunkedState::Body | ChunkedState::EndCr) {
                state = state
                    .step(&mut rdr, &mut size, &mut None, &mut Trailers::new(64))
                    .unwrap_or_else(|_| panic!("read_size failed for {s:?}"))
                    .unwrap();
            }
            size
        }

        // Runs the state machine over `s` and expects it to fail with `expected_err`.
        // Running out of input is treated as `UnexpectedEof`.
        fn read_err(s: &str, expected_err: io::ErrorKind) {
            let mut rdr = BytesMut::from(s);
            let mut size = 0;
            let mut state = ChunkedState::Size;
            loop {
                match state.step(&mut rdr, &mut size, &mut None, &mut Trailers::new(64)) {
                    Ok(Some(next)) => state = next,
                    Ok(None) => return assert_eq!(expected_err, UnexpectedEof),
                    Err(e) => {
                        return assert_eq!(
                            expected_err,
                            e.kind(),
                            "Reading {:?}, expected {:?}, but got {:?}",
                            s,
                            expected_err,
                            e.kind()
                        );
                    }
                }
                if matches!(state, ChunkedState::Body | ChunkedState::End) {
                    panic!("Was Ok. Expected Err for {s:?}");
                }
            }
        }

        // Plain sizes, with and without trailing whitespace.
        for (input, expected) in [
            ("1\r\n", 1),
            ("01\r\n", 1),
            ("0\r\n", 0),
            ("00\r\n", 0),
            ("A\r\n", 10),
            ("a\r\n", 10),
            ("Ff\r\n", 255),
            ("Ff \r\n", 255),
        ] {
            assert_eq!(expected, read(input));
        }

        // Malformed size lines.
        for (input, kind) in [
            ("F\rF", InvalidInput),
            ("F", UnexpectedEof),
            ("X\r\n", InvalidInput),
            ("1X\r\n", InvalidInput),
            ("-\r\n", InvalidInput),
            ("-1\r\n", InvalidInput),
        ] {
            read_err(input, kind);
        }

        // Chunk extensions are skipped.
        for (input, expected) in [
            ("1;extension\r\n", 1),
            ("a;ext name=value\r\n", 10),
            ("1;extension;extension2\r\n", 1),
            ("1;;; ;\r\n", 1),
            ("2; extension...\r\n", 2),
            ("3 ; extension=123\r\n", 3),
            ("3 ;\r\n", 3),
            ("3 ; \r\n", 3),
        ] {
            assert_eq!(expected, read(input));
        }

        // Malformed extensions and size overflow.
        for (input, kind) in [
            ("1 invalid extension\r\n", InvalidInput),
            ("1 A\r\n", InvalidInput),
            ("1;no CRLF", UnexpectedEof),
            ("1;reject\nnewlines\r\n", InvalidData),
            ("f0000000000000003\r\n", InvalidData),
        ] {
            read_err(input, kind);
        }
    }

    /// A single decode call yields the whole first chunk.
    #[test]
    fn test_read_chunked_single_read() {
        let mut src = BytesMut::from("10\r\n1234567890abcdef\r\n0\r\n");
        let chunk = match TransferCoding::decode_chunked(64).decode(&mut src) {
            ChunkResult::Ok(chunk) => chunk,
            other => panic!("{}", other),
        };
        assert_eq!(16, chunk.len());
        let text = String::from_utf8(chunk.as_ref().to_vec()).expect("decode String");
        assert_eq!("1234567890abcdef", &text);
    }

    /// A trailer line whose CR is not followed by LF is rejected.
    #[test]
    fn test_read_chunked_trailer_with_missing_lf() {
        let mut src = BytesMut::from("10\r\n1234567890abcdef\r\n0\r\nbad\r\r\n");
        let mut decoder = TransferCoding::decode_chunked(64);

        match decoder.decode(&mut src) {
            ChunkResult::Ok(_) => {}
            other => panic!("{}", other),
        }

        let err = match decoder.decode(&mut src) {
            ChunkResult::Err(e) => e,
            other => panic!("{}", other),
        };
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// After the body ends the decoder reports `OnEof` once and `AlreadyEof` afterwards.
    #[test]
    fn test_read_chunked_after_eof() {
        let mut src = BytesMut::from("10\r\n1234567890abcdef\r\n0\r\n\r\n");
        let mut decoder = TransferCoding::decode_chunked(64);

        let chunk = match decoder.decode(&mut src) {
            ChunkResult::Ok(chunk) => chunk,
            other => panic!("{}", other),
        };
        assert_eq!(16, chunk.len());
        let text = String::from_utf8(chunk.as_ref().to_vec()).unwrap();
        assert_eq!("1234567890abcdef", &text);

        match decoder.decode(&mut src) {
            ChunkResult::OnEof => {}
            other => panic!("{}", other),
        }

        match decoder.decode(&mut src) {
            ChunkResult::AlreadyEof => {}
            other => panic!("{}", other),
        }
    }

    /// Trailer lines after the last chunk are returned as headers before eof is reported.
    #[test]
    fn test_read_chunked_with_trailers() {
        let mut src = BytesMut::from(
            "5\r\nHello\r\n0\r\nExpires: Wed, 21 Oct 2015 07:28:00 GMT\r\nX-Checksum: abc123\r\n\r\n",
        );
        let mut decoder = TransferCoding::decode_chunked(64);

        let chunk = match decoder.decode(&mut src) {
            ChunkResult::Ok(chunk) => chunk,
            other => panic!("expected data chunk, got: {}", other),
        };
        assert_eq!(chunk.as_ref(), b"Hello");

        let headers = match decoder.decode(&mut src) {
            ChunkResult::Trailers(headers) => headers,
            other => panic!("expected trailers, got: {}", other),
        };
        assert_eq!(headers.len(), 2);
        assert_eq!(headers.get("Expires").unwrap(), "Wed, 21 Oct 2015 07:28:00 GMT");
        assert_eq!(headers.get("X-Checksum").unwrap(), "abc123");

        match decoder.decode(&mut src) {
            ChunkResult::OnEof => {}
            other => panic!("expected OnEof, got: {}", other),
        }
    }

    /// Each encoded message becomes one chunk; eof appends the terminating chunk.
    #[test]
    fn encode_chunked() {
        let mut encoder = TransferCoding::encode_chunked();
        let mut dst = BytesMut::default();

        encoder.encode(Bytes::from("foo bar"), &mut dst);
        assert_eq!(dst.as_ref(), b"7\r\nfoo bar\r\n");

        encoder.encode(Bytes::from("baz quux herp"), &mut dst);
        assert_eq!(dst.as_ref(), b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");

        encoder.encode_eof(None, &mut dst);
        assert_eq!(dst.as_ref(), b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n");
    }

    /// A length encoder never writes more than the announced number of bytes.
    #[test]
    fn encode_length() {
        const MAX_LEN: usize = 8;

        let mut encoder = TransferCoding::length(MAX_LEN as u64);
        let mut dst = BytesMut::default();

        encoder.encode(Bytes::from("foo bar"), &mut dst);
        assert_eq!(dst.as_ref(), b"foo bar");

        for _ in 0..8 {
            encoder.encode(Bytes::from("baz"), &mut dst);
            assert_eq!(dst.as_ref().len(), MAX_LEN);
            assert_eq!(dst.as_ref(), b"foo barb");
        }

        encoder.encode_eof(None, &mut dst);
        assert_eq!(dst.as_ref().len(), MAX_LEN);
        assert_eq!(dst.as_ref(), b"foo barb");
    }

    /// Trailers passed at eof are written after the zero-size chunk line.
    #[test]
    fn encode_chunked_with_trailers() {
        let mut encoder = TransferCoding::encode_chunked();
        let mut dst = BytesMut::default();
        encoder.encode(Bytes::from("hello"), &mut dst);

        let mut trailers = HeaderMap::new();
        trailers.insert("x-checksum", HeaderValue::from_static("abc123"));
        trailers.insert("x-status", HeaderValue::from_static("ok"));
        trailers.append("x-status", HeaderValue::from_static("done"));
        encoder.encode_eof(Some(trailers), &mut dst);

        assert_eq!(
            dst.as_ref(),
            b"5\r\nhello\r\n0\r\nx-checksum: abc123\r\nx-status: ok\r\nx-status: done\r\n\r\n"
        );
    }

    /// Of the trailers given here, only `x-checksum` is expected in the output.
    #[test]
    fn encode_chunked_trailers_filters_forbidden() {
        use crate::http::header;

        let mut encoder = TransferCoding::encode_chunked();
        let mut dst = BytesMut::default();
        encoder.encode(Bytes::from("data"), &mut dst);

        let mut trailers = HeaderMap::new();
        trailers.insert("x-checksum", HeaderValue::from_static("abc123"));
        trailers.insert(header::CONTENT_LENGTH, HeaderValue::from_static("99"));
        trailers.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/plain"));
        trailers.insert(header::HOST, HeaderValue::from_static("example.com"));
        trailers.insert(header::TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
        encoder.encode_eof(Some(trailers), &mut dst);

        assert_eq!(dst.as_ref(), b"4\r\ndata\r\n0\r\nx-checksum: abc123\r\n\r\n");
    }

    /// When no trailer is expected to be written, the output ends with the plain terminator.
    #[test]
    fn encode_chunked_trailers_empty_after_filter() {
        use crate::http::header;

        let mut encoder = TransferCoding::encode_chunked();
        let mut dst = BytesMut::default();
        encoder.encode(Bytes::from("x"), &mut dst);

        let mut trailers = HeaderMap::new();
        trailers.insert(header::CONTENT_LENGTH, HeaderValue::from_static("1"));
        encoder.encode_eof(Some(trailers), &mut dst);

        assert_eq!(dst.as_ref(), b"1\r\nx\r\n0\r\n\r\n");
    }
}