use std::pin::Pin;
use async_std::io;
use async_std::io::prelude::*;
use async_std::task::{Context, Poll};
use http_types::headers::{CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use http_types::{Method, Response};
use crate::chunked::ChunkedEncoder;
use crate::date::fmt_http_date;
/// Server-side HTTP/1.1 response encoder.
///
/// The serialized response (status line, headers, then body) is produced
/// incrementally through this type's `Read` implementation, driven by an
/// internal `State` machine.
#[derive(Debug)]
pub struct Encoder {
    /// NOTE(review): initialized to 0 in `Encoder::new` but never read or
    /// updated anywhere else in the visible code — possibly a leftover
    /// recursion counter; confirm before relying on it.
    depth: u16,
    /// The response being encoded; its body is read as output is produced.
    res: Response,
    /// Current position in the encoding state machine.
    state: State,
    /// Bytes copied into the caller's buffer during the current `poll_read`
    /// call; reset to 0 at the start of every call.
    bytes_written: usize,
    /// Serialized status line and headers, built once by `compute_head`.
    head: Vec<u8>,
    /// How many bytes of `head` have already been handed to the caller.
    head_bytes_written: usize,
    /// Body length taken from `Response::len()` when it is known
    /// (fixed-length framing); left at 0 otherwise.
    body_len: usize,
    /// How many fixed-length body bytes have been handed to the caller.
    body_bytes_written: usize,
    /// Encoder used to frame the body when its length is not known.
    chunked: ChunkedEncoder,
    /// Method of the request being answered; for `HEAD` no body is emitted.
    method: Method,
}
/// States of the server `Encoder` state machine.
///
/// Transitions checked in debug builds by `Encoder::dispatch`:
/// `Start -> ComputeHead -> EncodeHead -> (EncodeFixedBody | EncodeChunkedBody | End)`,
/// and each body state `-> End`. No transition out of `End` is allowed.
#[derive(Debug)]
enum State {
    /// Nothing has been done yet.
    Start,
    /// Serialize the status line and headers into the head buffer.
    ComputeHead,
    /// Copy the serialized head into the caller's buffer.
    EncodeHead,
    /// Copy a body whose length is known (`content-length` framing).
    EncodeFixedBody,
    /// Copy a body of unknown length using chunked transfer encoding.
    EncodeChunkedBody,
    /// Encoding finished; every further read yields 0 bytes.
    End,
}
impl Read for Encoder {
    /// Fills `buf` with the next portion of the encoded HTTP response.
    ///
    /// The per-call byte counter is cleared first so the state machine only
    /// reports what it produced during this particular poll; a result of
    /// `Ok(0)` signals that the whole response has been emitted.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let encoder = &mut *self;
        encoder.bytes_written = 0;
        let outcome = encoder.run(cx, buf);
        log::trace!("ServerEncoder {} bytes written", encoder.bytes_written);
        outcome
    }
}
impl Encoder {
    /// Creates an encoder that serializes `res` as an HTTP/1.1 response.
    ///
    /// `method` is the method of the request being answered: for `HEAD`
    /// requests only the status line and headers are emitted, never the body.
    pub fn new(res: Response, method: Method) -> Self {
        Self {
            res,
            depth: 0,
            state: State::Start,
            bytes_written: 0,
            head: vec![],
            head_bytes_written: 0,
            body_len: 0,
            body_bytes_written: 0,
            chunked: ChunkedEncoder::new(),
            method,
        }
    }

    /// Moves the state machine to `state` and immediately runs it.
    ///
    /// # Panics
    ///
    /// In debug builds, panics on a transition that is not part of the
    /// state graph, and on any transition out of `State::End`.
    fn dispatch(
        &mut self,
        state: State,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        use State::*;
        log::trace!("ServerEncoder state: {:?} -> {:?}", self.state, state);

        #[cfg(debug_assertions)]
        match self.state {
            Start => assert!(matches!(state, ComputeHead)),
            ComputeHead => assert!(matches!(state, EncodeHead)),
            EncodeHead => assert!(matches!(state, EncodeChunkedBody | EncodeFixedBody | End)),
            EncodeFixedBody => assert!(matches!(state, End)),
            EncodeChunkedBody => assert!(matches!(state, End)),
            End => panic!("No state transitions allowed after the ServerEncoder has ended"),
        }

        self.state = state;
        self.run(cx, buf)
    }

    /// Executes the handler for the current state.
    ///
    /// Returns the number of bytes written into `buf` during this poll;
    /// once the encoder has ended that number is 0.
    fn run(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        match self.state {
            State::Start => self.dispatch(State::ComputeHead, cx, buf),
            State::ComputeHead => self.compute_head(cx, buf),
            State::EncodeHead => self.encode_head(cx, buf),
            State::EncodeFixedBody => self.encode_fixed_body(cx, buf),
            State::EncodeChunkedBody => self.encode_chunked_body(cx, buf),
            State::End => Poll::Ready(Ok(self.bytes_written)),
        }
    }

    /// Serializes the status line and headers into `self.head`, then moves
    /// on to `EncodeHead`.
    ///
    /// Framing is decided here: `content-length` when `Response::len()` is
    /// known, `transfer-encoding: chunked` otherwise. Any caller-supplied
    /// values for those two headers are skipped so they cannot contradict
    /// the chosen framing. A `date` header is added if the response has none.
    fn compute_head(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        // Fully-qualified `std::io::Write` calls are used on purpose: the
        // async-std prelude brings an async `write_fmt` into scope for `Vec<u8>`.
        let status = self.res.status();
        let reason = status.canonical_reason();
        std::io::Write::write_fmt(
            &mut self.head,
            format_args!("HTTP/1.1 {} {}\r\n", status, reason),
        )?;

        if let Some(len) = self.res.len() {
            std::io::Write::write_fmt(&mut self.head, format_args!("content-length: {}\r\n", len))?;
        } else {
            self.head.extend_from_slice(b"transfer-encoding: chunked\r\n");
        }

        if self.res.header(DATE).is_none() {
            let date = fmt_http_date(std::time::SystemTime::now());
            std::io::Write::write_fmt(&mut self.head, format_args!("date: {}\r\n", date))?;
        }

        let headers = self
            .res
            .iter()
            .filter(|(h, _)| h != &&CONTENT_LENGTH)
            .filter(|(h, _)| h != &&TRANSFER_ENCODING);
        for (header, values) in headers {
            for value in values.iter() {
                std::io::Write::write_fmt(
                    &mut self.head,
                    format_args!("{}: {}\r\n", header, value),
                )?;
            }
        }

        // Blank line terminating the header section.
        self.head.extend_from_slice(b"\r\n");
        self.dispatch(State::EncodeHead, cx, buf)
    }

    /// Copies as much of the serialized head as fits into `buf`.
    ///
    /// When the head is complete, moves to `End` for `HEAD` requests,
    /// otherwise to the fixed-length or chunked body state depending on
    /// whether `Response::len()` is known.
    fn encode_head(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        let head_len = self.head.len();
        let len = std::cmp::min(head_len - self.head_bytes_written, buf.len());
        let range = self.head_bytes_written..self.head_bytes_written + len;
        // The head is always the first thing written in a poll, so it starts at 0.
        buf[0..len].copy_from_slice(&self.head[range]);
        self.bytes_written += len;
        self.head_bytes_written += len;

        // Part of the head is still pending: hand out what fit and resume later.
        if self.head_bytes_written < head_len {
            return Poll::Ready(Ok(self.bytes_written));
        }

        if self.method == Method::Head {
            return self.dispatch(State::End, cx, buf);
        }

        match self.res.len() {
            Some(body_len) => {
                self.body_len = body_len;
                self.dispatch(State::EncodeFixedBody, cx, buf)
            }
            None => self.dispatch(State::EncodeChunkedBody, cx, buf),
        }
    }

    /// Copies body bytes into `buf` when the body length is known up front.
    ///
    /// Reads are capped so no more than `body_len` body bytes are emitted.
    /// Moves to `End` once the whole body has been written, or when the body
    /// reader reports EOF early. In that case the output is truncated
    /// relative to the advertised `content-length`; this is best effort.
    ///
    /// Implemented as a loop rather than self-recursion. Rust does not
    /// guarantee tail calls, and a reader that yields a few bytes at a time
    /// would otherwise nest one stack frame per read while filling `buf`.
    fn encode_fixed_body(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        loop {
            // Never claim more output than the caller's buffer can hold.
            debug_assert!(self.bytes_written <= buf.len());

            // Whole body emitted. This also covers an empty body, which then
            // never polls the reader with a zero-length slice.
            if self.body_bytes_written == self.body_len {
                return self.dispatch(State::End, cx, buf);
            }

            // Caller's buffer is full; continue on the next poll.
            if self.bytes_written == buf.len() {
                return Poll::Ready(Ok(self.bytes_written));
            }

            let remaining_body = self.body_len - self.body_bytes_written;
            let upper_bound = (self.bytes_written + remaining_body).min(buf.len());
            let range = self.bytes_written..upper_bound;
            let inner_poll_result = Pin::new(&mut self.res).poll_read(cx, &mut buf[range]);
            let read = match inner_poll_result {
                Poll::Ready(Ok(n)) => n,
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                // Report Pending only when nothing at all was produced in this
                // poll; otherwise flush what is already in `buf`.
                Poll::Pending => match self.bytes_written {
                    0 => return Poll::Pending,
                    n => return Poll::Ready(Ok(n)),
                },
            };
            self.body_bytes_written += read;
            self.bytes_written += read;

            debug_assert!(
                self.body_bytes_written <= self.body_len,
                "Too many bytes read. Expected: {}, read: {}",
                self.body_len,
                self.body_bytes_written
            );

            if read == 0 {
                // Unexpected EOF before `body_len` bytes: end anyway (best effort).
                return self.dispatch(State::End, cx, buf);
            }
        }
    }

    /// Encodes a body of unknown length with chunked transfer encoding.
    ///
    /// Only the unfilled tail of `buf` is handed to the chunked encoder.
    /// Moves to `End` when a poll produces no bytes at all; presumably the
    /// chunked encoder returns 0 only after the terminating chunk has been
    /// written (TODO confirm against `ChunkedEncoder::encode`).
    fn encode_chunked_body(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let remaining = &mut buf[self.bytes_written..];
        match self.chunked.encode(&mut self.res, cx, remaining) {
            Poll::Ready(Ok(read)) => {
                self.bytes_written += read;
                match self.bytes_written {
                    0 => self.dispatch(State::End, cx, remaining),
                    _ => Poll::Ready(Ok(self.bytes_written)),
                }
            }
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending => match self.bytes_written {
                0 => Poll::Pending,
                _ => Poll::Ready(Ok(self.bytes_written)),
            },
        }
    }
}