#![allow(clippy::borrow_interior_mutable_const)]
use futures_lite::AsyncRead;
use http::header;
use httpdate::fmt_http_date;
use std::pin::Pin;
use std::task::{Context, Poll};
use tracing::trace;
use crate::chunked::ChunkedEncoder;
use super::response_writer::InnerResponse;
/// Streaming serializer that turns an `InnerResponse` into HTTP/1.x bytes,
/// handed out to readers through `AsyncRead`.
pub(crate) struct Encoder {
    /// Current phase of the encoding state machine.
    state: EncoderState,
    /// The response being serialized; its body supplies the payload bytes.
    resp: InnerResponse,
    /// Bytes written into the caller's buffer during the current `poll_read`
    /// call; reset to zero at the start of every call.
    bytes_read: usize,
    /// Serialized status line and header block.
    head_buf: Vec<u8>,
    /// How much of `head_buf` has already been copied out to readers.
    head_bytes_read: usize,
    /// Declared body length; `None` selects `transfer-encoding: chunked`.
    content_length: Option<usize>,
    /// Body bytes copied out so far when using `content-length` framing.
    body_bytes_read: usize,
    /// Chunk-framing state used when the body length is unknown.
    chunked: ChunkedEncoder,
}
impl Encoder {
    /// Creates an encoder that serializes `resp` into HTTP/1.x wire format.
    ///
    /// If the response body reports a length, the body is framed with a
    /// `content-length` header; otherwise it is framed with
    /// `transfer-encoding: chunked`.
    pub(crate) fn encode(resp: InnerResponse) -> Self {
        let content_length = resp.body.length;
        Self {
            resp,
            state: EncoderState::Start,
            bytes_read: 0,
            head_buf: Vec::new(),
            head_bytes_read: 0,
            content_length,
            body_bytes_read: 0,
            chunked: ChunkedEncoder::new(),
        }
    }

    /// Serializes the status line and headers into `head_buf`, then begins
    /// copying them (and possibly body bytes) into `buf`.
    fn start(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
        let version = self.resp.version;
        let status = self.resp.status;

        // Only synthesize a `date` header when the application did not set one.
        let date = if self.resp.headers.contains_key(header::DATE) {
            None
        } else {
            Some(fmt_http_date(std::time::SystemTime::now()))
        };

        // A response that may carry a body (unknown length or non-zero length)
        // gets a default content type if none was provided.
        if self.content_length != Some(0)
            && !self.resp.headers.contains_key(header::CONTENT_TYPE)
        {
            self.resp.headers.insert(
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/octet-stream"),
            );
        }

        std::io::Write::write_fmt(
            &mut self.head_buf,
            format_args!("{:?} {}\r\n", version, status),
        )?;

        // Framing is derived from `content_length`, never from user headers.
        if let Some(len) = self.content_length {
            std::io::Write::write_fmt(
                &mut self.head_buf,
                format_args!("content-length: {}\r\n", len),
            )?;
        } else {
            self.head_buf
                .extend_from_slice(b"transfer-encoding: chunked\r\n");
        }

        if let Some(date) = date {
            std::io::Write::write_fmt(&mut self.head_buf, format_args!("date: {}\r\n", date))?;
        }

        for (name, value) in self.resp.headers.iter() {
            // User-supplied framing headers would contradict the ones written above.
            if *name == header::CONTENT_LENGTH || *name == header::TRANSFER_ENCODING {
                continue;
            }
            std::io::Write::write_fmt(&mut self.head_buf, format_args!("{}: ", name))?;
            self.head_buf.extend_from_slice(value.as_bytes());
            self.head_buf.extend_from_slice(b"\r\n");
        }
        // Blank line terminating the header section.
        self.head_buf.extend_from_slice(b"\r\n");

        self.state = EncoderState::Head;
        self.encode_head(cx, buf)
    }

    /// Copies as much of the pending head as fits into `buf`. Once the head is
    /// fully copied, moves on to the body in the same call.
    fn encode_head(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let pending = &self.head_buf[self.head_bytes_read..];
        let out = &mut buf[self.bytes_read..];
        let len = pending.len().min(out.len());
        out[..len].copy_from_slice(&pending[..len]);
        self.bytes_read += len;
        self.head_bytes_read += len;

        if self.head_bytes_read < self.head_buf.len() {
            // `buf` is full; the rest of the head goes out on the next poll.
            return Poll::Ready(Ok(self.bytes_read));
        }

        match self.content_length {
            Some(_) => {
                self.state = EncoderState::FixedBody;
                self.encode_fixed_body(cx, buf)
            }
            None => {
                self.state = EncoderState::ChunkedBody;
                trace!("Server response encoding: chunked body");
                self.encode_chunked_body(cx, buf)
            }
        }
    }

    /// Copies body bytes into the free tail of `buf`, never exceeding the
    /// declared `content_length`.
    ///
    /// # Errors
    ///
    /// Returns `UnexpectedEof` if the body ends before `content_length` bytes
    /// were produced. The response cannot be completed correctly at that point,
    /// because the peer would wait for bytes that never arrive.
    fn encode_fixed_body(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let content_length = self
            .content_length
            .expect("content_length.is_some() checked before entering method");

        // Iterate instead of recursing so many short reads cannot grow the stack.
        loop {
            if self.body_bytes_read >= content_length {
                self.state = EncoderState::Done;
                return Poll::Ready(Ok(self.bytes_read));
            }
            if self.bytes_read == buf.len() {
                return Poll::Ready(Ok(self.bytes_read));
            }

            let remaining_body = content_length - self.body_bytes_read;
            let upper_limit = buf.len().min(self.bytes_read + remaining_body);
            // Non-empty by the two checks above, so `Ok(0)` below really means EOF.
            let out = &mut buf[self.bytes_read..upper_limit];
            let polled = Pin::new(&mut self.resp.body).poll_read(cx, out);

            match polled {
                Poll::Ready(Ok(0)) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "response body ended before its declared content-length",
                    )));
                }
                Poll::Ready(Ok(n)) => {
                    self.bytes_read += n;
                    self.body_bytes_read += n;
                }
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                // Hand back whatever this call already produced rather than stalling it.
                Poll::Pending if self.bytes_read == 0 => return Poll::Pending,
                Poll::Pending => return Poll::Ready(Ok(self.bytes_read)),
            }
        }
    }

    /// Delegates chunk framing of the body to `ChunkedEncoder`, writing into
    /// the free tail of `buf`.
    fn encode_chunked_body(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let out = &mut buf[self.bytes_read..];
        if out.is_empty() {
            // No room left. Polling the chunk encoder with an empty slice could
            // yield 0 bytes, which would be mistaken for end-of-body below.
            return Poll::Ready(Ok(self.bytes_read));
        }

        match self.chunked.encode(&mut self.resp.body, cx, out) {
            Poll::Ready(Ok(read)) => {
                self.bytes_read += read;
                // Nothing produced in this call despite available space: treated
                // as the end of the chunked stream (presumably the encoder has
                // already emitted the terminating chunk — confirm in ChunkedEncoder).
                if self.bytes_read == 0 {
                    self.state = EncoderState::Done;
                }
                Poll::Ready(Ok(self.bytes_read))
            }
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending if self.bytes_read > 0 => Poll::Ready(Ok(self.bytes_read)),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl AsyncRead for Encoder {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
self.bytes_read = 0;
use EncoderState::*;
match self.state {
Start => self.start(cx, buf),
Head => self.encode_head(cx, buf),
FixedBody => self.encode_fixed_body(cx, buf),
ChunkedBody => self.encode_chunked_body(cx, buf),
Done => Poll::Ready(Ok(0)),
}
}
}
/// Phases the response encoder moves through while it is being read.
#[derive(Debug)]
enum EncoderState {
    /// Nothing serialized yet; the head is built on the first poll.
    Start,
    /// The serialized head is being copied out.
    Head,
    /// Copying a body of known length (`content-length` framing).
    FixedBody,
    /// Copying a body of unknown length (`transfer-encoding: chunked` framing).
    ChunkedBody,
    /// The full response has been produced; further reads yield 0 bytes.
    Done,
}