use crate::chunked::{ChunkedDecoder, ChunkedEncoder};
use crate::fast_buf::FastBuf;
use crate::share::is_closed_kind;
use crate::AsyncRead;
use crate::Error;
use futures_util::ready;
use std::fmt;
use std::io;
use std::pin::Pin;
use std::str::FromStr;
use std::task::{Context, Poll};
/// Strategy for reading an incoming HTTP/1.1 body, chosen from the
/// message headers (see `from_headers`).
pub(crate) enum LimitRead {
    /// Body framed with `transfer-encoding: chunked`.
    ChunkedDecoder(ChunkedDecoder),
    /// Body delimited by a `content-length` header.
    ContentLength(ContentLengthRead),
    /// Server response body delimited by connection close.
    ReadToEnd(ReadToEnd),
    /// No body present.
    NoBody,
}
impl LimitRead {
    /// Determine the read strategy from the message headers.
    ///
    /// Precedence: chunked transfer-encoding, then content-length, then
    /// (for server responses only) read-until-close, otherwise no body.
    pub fn from_headers(
        headers: &http::HeaderMap<http::HeaderValue>,
        is_server_response: bool,
    ) -> Self {
        let ret = if is_chunked(headers) {
            LimitRead::ChunkedDecoder(ChunkedDecoder::new())
        } else {
            match get_as::<u64>(headers, "content-length") {
                Some(size) => LimitRead::ContentLength(ContentLengthRead::new(size)),
                None if is_server_response => LimitRead::ReadToEnd(ReadToEnd::new()),
                None => LimitRead::NoBody,
            }
        };
        trace!("LimitRead from headers: {:?}", ret);
        ret
    }

    /// True when there is definitely no body to read.
    pub fn is_no_body(&self) -> bool {
        matches!(self, LimitRead::NoBody)
            || matches!(self, LimitRead::ContentLength(r) if r.limit == 0)
    }

    /// True once the entire body has been consumed.
    pub fn is_complete(&self) -> bool {
        match self {
            LimitRead::ChunkedDecoder(v) => v.is_end(),
            LimitRead::ContentLength(v) => v.is_end(),
            LimitRead::ReadToEnd(v) => v.is_end(),
            LimitRead::NoBody => true,
        }
    }

    /// The declared body size, when known up front (content-length only).
    pub fn body_size(&self) -> Option<u64> {
        match self {
            LimitRead::ContentLength(v) => Some(v.limit),
            _ => None,
        }
    }

    /// Whether the connection can be reused for another exchange.
    ///
    /// A read-to-end body is terminated by connection close, so such a
    /// connection can never be reused.
    pub fn is_reusable(&self) -> bool {
        self.is_complete() && !self.is_read_to_end()
    }

    fn is_read_to_end(&self) -> bool {
        matches!(self, LimitRead::ReadToEnd(_))
    }

    /// Whether an entire body vec can be accounted for in one go.
    pub fn can_read_entire_vec(&self) -> bool {
        matches!(self, LimitRead::ContentLength(_))
    }

    /// Account for an entire body read outside `poll_read`.
    ///
    /// Panics unless the strategy is content-length.
    pub fn accept_entire_vec(&mut self, buf: &[u8]) {
        match self {
            LimitRead::ContentLength(v) => v.total += buf.len() as u64,
            _ => panic!("accept_entire_vec with wrong type of writer"),
        }
    }

    /// Read body bytes into `buf`, delegating to the active strategy.
    pub fn poll_read<S: AsyncRead + Unpin>(
        &mut self,
        cx: &mut Context,
        recv: &mut BufIo<S>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        match self {
            LimitRead::ChunkedDecoder(v) => v.poll_read(cx, recv, buf),
            LimitRead::ContentLength(v) => v.poll_read(cx, recv, buf),
            LimitRead::ReadToEnd(v) => v.poll_read(cx, recv, buf),
            LimitRead::NoBody => Poll::Ready(Ok(0)),
        }
    }
}
/// True when the headers declare a body via either chunked
/// transfer-encoding or a content-length header.
pub(crate) fn headers_indicate_body(headers: &http::HeaderMap<http::HeaderValue>) -> bool {
    let has_content_length = get_as::<u64>(headers, "content-length").is_some();
    is_chunked(headers) || has_content_length
}
use crate::buf_reader::BufIo;
/// Body reader for a `content-length` delimited body.
#[derive(Debug)]
pub struct ContentLengthRead {
    /// Expected total body size from the content-length header.
    limit: u64,
    /// Bytes delivered to the caller so far.
    total: u64,
}
impl ContentLengthRead {
    /// Create a reader expecting exactly `limit` body bytes.
    fn new(limit: u64) -> Self {
        ContentLengthRead { limit, total: 0 }
    }

    /// True once all `limit` bytes have been read.
    fn is_end(&self) -> bool {
        self.total == self.limit
    }

    /// Read up to the remaining content-length bytes into `buf`.
    ///
    /// Returns `Ok(0)` once the declared length has been consumed. A
    /// zero-length read from the underlying stream before that point means
    /// the peer closed early, which is reported as `UnexpectedEof`.
    fn poll_read<R: AsyncRead + Unpin>(
        &mut self,
        cx: &mut Context,
        recv: &mut BufIo<R>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        assert!(!buf.is_empty(), "poll_read with len 0 buf");

        // Bytes still expected. Clamp before casting so the conversion
        // can't truncate on 32-bit targets. (usize::MAX replaces the
        // deprecated usize::max_value().)
        let left = (self.limit - self.total).min(usize::MAX as u64) as usize;

        if left == 0 {
            // Entire declared body already delivered.
            return Ok(0).into();
        }

        let max = buf.len().min(left);
        let amount = ready!(Pin::new(&mut *recv).poll_read(cx, &mut buf[0..max]))?;

        // `left > 0` is guaranteed here (early return above), so EOF means
        // the body was cut short. The original `left > 0 &&` was redundant.
        if amount == 0 {
            let msg = format!(
                "Partial body received {} bytes and expected {}",
                self.total, self.limit
            );
            trace!("{}", msg);
            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, msg)).into();
        }

        self.total += amount as u64;

        Ok(amount).into()
    }
}
/// Body reader that consumes the stream until the peer closes it.
pub(crate) struct ReadToEnd {
    /// Set once EOF (or a close-like error) has been observed.
    reached_end: bool,
}
impl ReadToEnd {
    /// Create a reader that consumes until the connection closes.
    fn new() -> Self {
        ReadToEnd { reached_end: false }
    }

    /// True once EOF (or a close-like error) has been observed.
    fn is_end(&self) -> bool {
        self.reached_end
    }

    /// Read from the underlying stream until it reports EOF.
    ///
    /// Close-like I/O errors are treated the same as a clean EOF.
    fn poll_read<R: AsyncRead + Unpin>(
        &mut self,
        cx: &mut Context,
        recv: &mut BufIo<R>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        assert!(!buf.is_empty(), "poll_read with len 0 buf");

        if self.reached_end {
            return Ok(0).into();
        }

        let result = ready!(Pin::new(&mut *recv).poll_read(cx, buf));

        match result {
            Ok(0) => {
                // Clean EOF terminates the body.
                self.reached_end = true;
                Ok(0).into()
            }
            Ok(amount) => Ok(amount).into(),
            // A closed connection is the expected terminator, not an error.
            Err(ref e) if is_closed_kind(e.kind()) => {
                self.reached_end = true;
                Ok(0).into()
            }
            Err(e) => Err(e).into(),
        }
    }
}
/// Strategy for writing an outgoing HTTP/1.1 body, chosen from the
/// message headers (see `from_headers`).
pub(crate) enum LimitWrite {
    /// Body framed with `transfer-encoding: chunked`.
    ChunkedEncoder,
    /// Body limited by a `content-length` header.
    ContentLength(ContentLengthWrite),
    /// No body to write.
    NoBody,
}
impl LimitWrite {
    /// Determine the write strategy from the outgoing headers.
    pub fn from_headers(headers: &http::HeaderMap<http::HeaderValue>) -> Self {
        let ret = if is_chunked(headers) {
            LimitWrite::ChunkedEncoder
        } else {
            match get_as::<u64>(headers, "content-length") {
                Some(limit) => LimitWrite::ContentLength(ContentLengthWrite::new(limit)),
                None => LimitWrite::NoBody,
            }
        };
        trace!("LimitWrite from headers: {:?}", ret);
        ret
    }

    /// Extra bytes of per-write framing overhead to budget for.
    pub fn overhead(&self) -> usize {
        // Chunked framing needs room for the length prefix and CRLFs.
        if let LimitWrite::ChunkedEncoder = self {
            32
        } else {
            0
        }
    }

    /// True when no body will be written.
    pub fn is_no_body(&self) -> bool {
        match self {
            LimitWrite::ChunkedEncoder => false,
            LimitWrite::ContentLength(w) => w.limit == 0,
            LimitWrite::NoBody => true,
        }
    }

    /// Whether an entire body vec can be accounted for in one go.
    pub fn can_write_entire_vec(&self) -> bool {
        matches!(self, LimitWrite::ContentLength(_))
    }

    /// Account for an entire body written outside `write`.
    ///
    /// Panics unless the strategy is content-length.
    pub fn accept_entire_vec(&mut self, buf: &[u8]) {
        match self {
            LimitWrite::ContentLength(v) => v.total += buf.len() as u64,
            _ => panic!("accept_entire_vec with wrong type of writer"),
        }
    }

    /// Encode `data` into `out` according to the active strategy.
    pub fn write(&mut self, data: &[u8], out: &mut FastBuf) -> Result<(), Error> {
        match self {
            LimitWrite::ChunkedEncoder => ChunkedEncoder::write_chunk(data, out),
            LimitWrite::ContentLength(v) => v.write(data, out),
            LimitWrite::NoBody => Ok(()),
        }
    }

    /// Emit any trailing framing (the chunked end marker).
    pub fn finish(&mut self, out: &mut FastBuf) -> Result<(), Error> {
        if let LimitWrite::ChunkedEncoder = self {
            ChunkedEncoder::write_finish(out)
        } else {
            Ok(())
        }
    }
}
/// Body writer that enforces a `content-length` limit.
#[derive(Debug)]
pub(crate) struct ContentLengthWrite {
    /// Declared content-length.
    limit: u64,
    /// Bytes written so far.
    total: u64,
}
impl ContentLengthWrite {
    /// Create a writer that allows at most `limit` body bytes.
    fn new(limit: u64) -> Self {
        ContentLengthWrite { limit, total: 0 }
    }

    /// Append `data` to `out`, erroring if the running total would exceed
    /// the declared content-length.
    fn write(&mut self, data: &[u8], out: &mut FastBuf) -> Result<(), Error> {
        if data.is_empty() {
            return Ok(());
        }

        self.total += data.len() as u64;

        if self.total > self.limit {
            let m = format!(
                "Body data longer than content-length header: {} > {}",
                self.total, self.limit
            );
            return Err(Error::User(m));
        }

        let mut target = out.borrow();
        target.extend_from_slice(data);

        Ok(())
    }
}
impl fmt::Debug for LimitRead {
    /// Compact variant name (with the size for content-length) for trace logs.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            LimitRead::ChunkedDecoder(_) => write!(f, "ChunkedDecoder"),
            // Fixed typo: was "ContenLength" — now matches LimitWrite's Debug.
            LimitRead::ContentLength(l) => write!(f, "ContentLength({})", l.limit),
            LimitRead::ReadToEnd(_) => write!(f, "ReadToEnd"),
            LimitRead::NoBody => write!(f, "NoBody"),
        }
    }
}
impl fmt::Debug for LimitWrite {
    /// Compact variant name (with the limit for content-length) for trace logs.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            LimitWrite::ChunkedEncoder => write!(f, "ChunkedEncoder"),
            LimitWrite::ContentLength(l) => write!(f, "ContentLength({})", l.limit),
            LimitWrite::NoBody => write!(f, "NoBody"),
        }
    }
}
/// True when the headers select chunked transfer framing.
///
/// Per RFC 7230 §3.3.3, chunked framing applies only when the "chunked"
/// coding is listed in `transfer-encoding`. The previous check treated any
/// value other than "identity" (e.g. "gzip" alone) as chunked, which would
/// wrongly engage the chunked decoder/encoder. Transfer-coding names are
/// case-insensitive, so the value is lowercased before matching.
fn is_chunked(headers: &http::HeaderMap<http::HeaderValue>) -> bool {
    headers
        .get("transfer-encoding")
        .and_then(|h| h.to_str().ok())
        .map(|h| h.to_ascii_lowercase().contains("chunked"))
        .unwrap_or(false)
}
/// Whether the connection may be kept open for another exchange, based on
/// the `connection` header and the HTTP version's default.
pub fn allow_reuse(headers: &http::HeaderMap<http::HeaderValue>, version: http::Version) -> bool {
    // HTTP/1.1 defaults to keep-alive; older versions default to close.
    let default_keep_alive = version == http::Version::HTTP_11;
    is_keep_alive(headers, default_keep_alive)
}
/// Read the `connection` header; fall back to `default` when it is absent,
/// unreadable, or carries an unrecognized value.
fn is_keep_alive(headers: &http::HeaderMap<http::HeaderValue>, default: bool) -> bool {
    match headers.get("connection").and_then(|h| h.to_str().ok()) {
        Some("keep-alive") => true,
        Some("close") => false,
        _ => default,
    }
}
/// Header value as a `&str`, if present and representable as visible ASCII.
fn get_str<'a>(headers: &'a http::HeaderMap, key: &str) -> Option<&'a str> {
    let value = headers.get(key)?;
    value.to_str().ok()
}
/// Header value parsed into `T`, if present and parseable.
fn get_as<T: FromStr>(headers: &http::HeaderMap, key: &str) -> Option<T> {
    get_str(headers, key)?.parse().ok()
}