use std::fmt;
use std::io::Write;
use http::{header, HeaderName, HeaderValue, Method};
use crate::chunk::Dechunker;
use crate::util::{compare_lowercase_ascii, log_data, Writer};
use crate::Error;
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct BodyWriter {
mode: SenderMode,
ended: bool,
}
#[derive(Debug, Clone, Copy)]
enum SenderMode {
None,
Sized(u64),
Chunked,
}
impl Default for SenderMode {
fn default() -> Self {
Self::None
}
}
pub(crate) const DEFAULT_CHUNK_SIZE: usize = 10 * 1024;
pub(crate) const DEFAULT_CHUNK_OVERHEAD: usize = 4 + 4;
pub(crate) const DEFAULT_CHUNK_AND_OVERHEAD: usize = DEFAULT_CHUNK_SIZE + DEFAULT_CHUNK_OVERHEAD;
impl BodyWriter {
pub fn new_none() -> Self {
BodyWriter {
mode: SenderMode::None,
ended: true,
}
}
pub fn new_chunked() -> Self {
BodyWriter {
mode: SenderMode::Chunked,
ended: false,
}
}
pub fn new_sized(size: u64) -> Self {
BodyWriter {
mode: SenderMode::Sized(size),
ended: false,
}
}
pub fn has_body(&self) -> bool {
matches!(self.mode, SenderMode::Sized(_) | SenderMode::Chunked)
}
pub fn is_chunked(&self) -> bool {
matches!(self.mode, SenderMode::Chunked)
}
pub fn write(&mut self, input: &[u8], w: &mut Writer) -> usize {
match &mut self.mode {
SenderMode::None => unreachable!(),
SenderMode::Sized(left) => {
let left_usize = (*left).min(usize::MAX as u64) as usize;
let to_write = w.available().min(input.len()).min(left_usize);
let success = w.try_write(|w| w.write_all(&input[..to_write]));
assert!(success);
*left -= to_write as u64;
if *left == 0 {
self.ended = true;
}
to_write
}
SenderMode::Chunked => {
let mut input_used = 0;
if input.is_empty() {
self.finish(w);
self.ended = true;
} else {
while write_chunk(
&input[input_used..],
&mut input_used,
w,
DEFAULT_CHUNK_SIZE,
) {}
}
input_used
}
}
}
fn finish(&self, w: &mut Writer) -> bool {
if self.is_chunked() {
let success = w.try_write(|w| w.write_all(b"0\r\n\r\n"));
if !success {
return false;
}
}
true
}
pub(crate) fn body_header(&self) -> (HeaderName, HeaderValue) {
match self.mode {
SenderMode::None => unreachable!(),
SenderMode::Sized(size) => (
header::CONTENT_LENGTH,
HeaderValue::from_str(&size.to_string()).unwrap(),
),
SenderMode::Chunked => (
header::TRANSFER_ENCODING,
HeaderValue::from_static("chunked"),
),
}
}
pub(crate) fn is_ended(&self) -> bool {
self.ended
}
pub(crate) fn left_to_send(&self) -> Option<u64> {
match self.mode {
SenderMode::Sized(v) => Some(v),
_ => None,
}
}
pub(crate) fn consume_direct_write(&mut self, amount: usize) {
match &mut self.mode {
SenderMode::None => unreachable!(),
SenderMode::Sized(left) => {
*left -= amount as u64;
if *left == 0 {
self.ended = true;
}
}
SenderMode::Chunked => unreachable!(),
}
}
}
#[allow(unused)]
pub(crate) fn calculate_chunk_overhead(output_len: usize) -> usize {
((output_len as f64).log(16.0) + 1.0).floor() as usize + 4
}
pub(crate) fn calculate_max_input(output_len: usize) -> usize {
let chunks = output_len / DEFAULT_CHUNK_AND_OVERHEAD;
let remaining = output_len % DEFAULT_CHUNK_AND_OVERHEAD;
let tail = if remaining <= DEFAULT_CHUNK_OVERHEAD {
0
} else {
remaining - DEFAULT_CHUNK_OVERHEAD
};
chunks * DEFAULT_CHUNK_SIZE + tail
}
fn write_chunk(input: &[u8], input_used: &mut usize, w: &mut Writer, max_chunk: usize) -> bool {
let available = w.available().saturating_sub(5);
let to_write = input.len().min(max_chunk).min(available);
let success = w.try_write(|w| {
write!(w, "{:0x?}\r\n", to_write)?;
w.write_all(&input[..to_write])?;
write!(w, "\r\n")
});
if success {
*input_used += to_write;
}
success && input.len() > to_write
}
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) enum BodyReader {
NoBody,
LengthDelimited(u64),
Chunked(Dechunker),
CloseDelimited,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BodyMode {
NoBody,
LengthDelimited(u64),
Chunked,
CloseDelimited,
}
impl BodyReader {
pub fn body_mode(&self) -> BodyMode {
match self {
BodyReader::NoBody => BodyMode::NoBody,
BodyReader::LengthDelimited(v) => BodyMode::LengthDelimited(*v),
BodyReader::Chunked(_) => BodyMode::Chunked,
BodyReader::CloseDelimited => BodyMode::CloseDelimited,
}
}
pub fn for_response<'a>(
http10: bool,
method: &Method,
status_code: u16,
header_lookup: &'a dyn Fn(HeaderName) -> Option<&'a str>,
) -> Result<Self, Error> {
let is_success = (200..=299).contains(&status_code);
let is_informational = (100..=199).contains(&status_code);
let is_redirect = (300..=399).contains(&status_code) && status_code != 304;
let header_defined = Self::header_defined(http10, header_lookup)?;
let has_body_header = header_defined != Self::CloseDelimited;
let has_no_body =
method == Method::HEAD ||
is_success && method == Method::CONNECT ||
is_informational ||
matches!(status_code, 204 | 304) ||
is_redirect && !has_body_header;
if has_no_body {
return Ok(Self::NoBody);
}
Ok(header_defined)
}
fn header_defined<'a>(
http10: bool,
header_lookup: &'a dyn Fn(HeaderName) -> Option<&'a str>,
) -> Result<Self, Error> {
let mut content_length: Option<u64> = None;
let mut chunked = false;
if let Some(value) = header_lookup(header::CONTENT_LENGTH) {
let v = value
.parse::<u64>()
.map_err(|_| Error::BadContentLengthHeader)?;
if content_length.is_some() {
return Err(Error::TooManyContentLengthHeaders);
}
content_length = Some(v);
}
if let Some(value) = header_lookup(header::TRANSFER_ENCODING) {
chunked = value
.split(',')
.map(|v| v.trim())
.any(|v| compare_lowercase_ascii(v, "chunked"));
}
if chunked && !http10 {
return Ok(Self::Chunked(Dechunker::new()));
}
if let Some(len) = content_length {
return Ok(Self::LengthDelimited(len));
}
Ok(Self::CloseDelimited)
}
pub fn read(
&mut self,
src: &[u8],
dst: &mut [u8],
stop_on_chunk_boundary: bool,
) -> Result<(usize, usize), Error> {
let part = match self {
BodyReader::LengthDelimited(_) => self.read_limit(src, dst),
BodyReader::Chunked(_) => self.read_chunked(src, dst, stop_on_chunk_boundary),
BodyReader::CloseDelimited => self.read_unlimit(src, dst),
BodyReader::NoBody => return Ok((0, 0)),
}?;
log_data(&src[..part.0]);
Ok(part)
}
fn read_limit(&mut self, src: &[u8], dst: &mut [u8]) -> Result<(usize, usize), Error> {
let left = match self {
BodyReader::LengthDelimited(v) => v,
_ => unreachable!(),
};
let left_usize = (*left).min(usize::MAX as u64) as usize;
let to_read = src.len().min(dst.len()).min(left_usize);
dst[..to_read].copy_from_slice(&src[..to_read]);
*left -= to_read as u64;
Ok((to_read, to_read))
}
fn read_chunked(
&mut self,
src: &[u8],
dst: &mut [u8],
stop_on_chunk_boundary: bool,
) -> Result<(usize, usize), Error> {
let dechunker = match self {
BodyReader::Chunked(v) => v,
_ => unreachable!(),
};
let mut input_used = 0;
let mut output_used = 0;
loop {
let (i, o) = dechunker.parse_input(&src[input_used..], &mut dst[output_used..])?;
input_used += i;
output_used += o;
if i == 0 || input_used == src.len() || output_used == dst.len() {
break;
}
if dechunker.is_ended() {
break;
}
if stop_on_chunk_boundary && dechunker.is_on_chunk_boundary() {
break;
}
}
Ok((input_used, output_used))
}
fn read_unlimit(&mut self, src: &[u8], dst: &mut [u8]) -> Result<(usize, usize), Error> {
let to_read = src.len().min(dst.len());
dst[..to_read].copy_from_slice(&src[..to_read]);
Ok((to_read, to_read))
}
pub fn is_ended(&self) -> bool {
match self {
BodyReader::NoBody => true,
BodyReader::LengthDelimited(v) => *v == 0,
BodyReader::Chunked(v) => v.is_ended(),
BodyReader::CloseDelimited => false,
}
}
pub(crate) fn is_on_chunk_boundary(&self) -> bool {
match self {
BodyReader::NoBody => false,
BodyReader::LengthDelimited(_) => false,
BodyReader::Chunked(v) => v.is_on_chunk_boundary(),
BodyReader::CloseDelimited => false,
}
}
}
impl fmt::Debug for BodyReader {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::NoBody => write!(f, "NoBody"),
Self::LengthDelimited(arg0) => f.debug_tuple("LengthDelimited").field(arg0).finish(),
Self::Chunked(_) => write!(f, "Chunked"),
Self::CloseDelimited => write!(f, "CloseDelimited"),
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_calculate_max_input() {
assert_eq!(calculate_max_input(0), 0);
assert_eq!(calculate_max_input(1), 0);
assert_eq!(calculate_max_input(2), 0);
assert_eq!(calculate_max_input(9), 1);
assert_eq!(calculate_max_input(10), 2);
assert_eq!(calculate_max_input(11), 3);
assert_eq!(calculate_max_input(10247), 10239);
assert_eq!(calculate_max_input(10248), 10240);
assert_eq!(calculate_max_input(10249), 10240);
assert_eq!(calculate_max_input(10250), 10240);
assert_eq!(calculate_max_input(10257), 10241);
assert_eq!(calculate_max_input(10258), 10242);
assert_eq!(calculate_max_input(10259), 10243);
}
}