use std::fmt;
use std::io::Write;
use std::marker::PhantomData;
use http::{HeaderName, HeaderValue, Method, Request, Response, StatusCode, Version};
use crate::body::{BodyReader, BodyWriter};
use crate::parser::{try_parse_partial_response, try_parse_response};
use crate::util::{log_data, Writer};
use crate::{BodyMode, Error};
use super::amended::AmendedRequest;
use super::MAX_RESPONSE_HEADERS;
/// Zero-sized marker types used as the `State` parameter of `Call`.
///
/// Each marker wraps a private `()` field, so code outside this module can
/// name the types but never construct them.
#[doc(hidden)]
pub mod state {
    /// Declares one hidden, zero-sized marker struct per listed name,
    /// forwarding any attributes (including doc comments) placed before it.
    macro_rules! markers {
        ($($(#[$meta:meta])* $name:ident),* $(,)?) => {
            $(
                $(#[$meta])*
                #[doc(hidden)]
                pub struct $name(());
            )*
        };
    }

    markers! {
        /// Request is sent without a body.
        WithoutBody,
        /// Request is sent with a body.
        WithBody,
        /// Waiting for the response head.
        RecvResponse,
        /// Reading the response body.
        RecvBody,
    }
}
use self::state::*;
/// An HTTP/1.1 request/response exchange, modelled as a type-state machine.
///
/// `State` is one of the markers from the `state` module and decides which
/// methods are available; `B` is the body type of the wrapped `http::Request`.
pub struct Call<State, B> {
    /// The caller's request, plus any headers added by `analyze_request`.
    request: AmendedRequest<B>,
    /// Set once `analyze_request` has run, so later calls are no-ops.
    analyzed: bool,
    /// Current phase together with the body encoder/decoder.
    state: BodyState,
    /// Carries the type-state marker without storing a value.
    _ph: PhantomData<State>,
}
impl<B> Call<(), B> {
    /// Starts a call for a request that is sent without a body.
    ///
    /// The request is not inspected yet; analysis happens on the first `write`.
    pub fn without_body(request: Request<B>) -> Result<Call<WithoutBody, B>, Error> {
        let writer = BodyWriter::new_none();
        Call::<WithoutBody, B>::new(request, writer)
    }

    /// Starts a call for a request that streams a body.
    ///
    /// The body writer starts out chunked. `analyze_request` later replaces it
    /// with the body mode that request analysis reports.
    pub fn with_body(request: Request<B>) -> Result<Call<WithBody, B>, Error> {
        let writer = BodyWriter::new_chunked();
        Call::<WithBody, B>::new(request, writer)
    }
}
impl<State, B> Call<State, B> {
fn new(request: Request<B>, default_body_mode: BodyWriter) -> Result<Self, Error> {
let request = AmendedRequest::new(request);
Ok(Call {
request,
analyzed: false,
state: BodyState {
writer: default_body_mode,
..Default::default()
},
_ph: PhantomData,
})
}
pub(crate) fn analyze_request(&mut self) -> Result<(), Error> {
if self.analyzed {
return Ok(());
}
let info = self
.request
.analyze(self.state.writer, self.state.skip_method_body_check)?;
if !info.req_host_header {
if let Some(host) = self.request.uri().host() {
let host =
HeaderValue::from_str(host).map_err(|e| Error::BadHeader(e.to_string()))?;
self.request.set_header("Host", host)?;
}
}
if !info.req_body_header && info.body_mode.has_body() {
let header = info.body_mode.body_header();
self.request.set_header(header.0, header.1)?;
}
self.state.writer = info.body_mode;
self.analyzed = true;
Ok(())
}
fn do_into_receive(self) -> Result<Call<RecvResponse, B>, Error> {
if !self.state.writer.is_ended() {
return Err(Error::UnfinishedRequest);
}
Ok(Call {
request: self.request,
analyzed: self.analyzed,
state: BodyState {
phase: Phase::RecvResponse,
..self.state
},
_ph: PhantomData,
})
}
pub(crate) fn amended(&self) -> &AmendedRequest<B> {
&self.request
}
pub(crate) fn amended_mut(&mut self) -> &mut AmendedRequest<B> {
&mut self.request
}
pub(crate) fn body_mode(&self) -> BodyMode {
self.state
.reader
.map(|r| r.body_mode())
.unwrap_or(BodyMode::Chunked)
}
}
/// Mutable progress data shared by every type-state of a `Call`.
#[derive(Default, Debug)]
struct BodyState {
    /// Where the call currently is in the request/response exchange.
    phase: Phase,
    /// Encoder for the outgoing request body.
    ///
    /// Overwritten with the analyzed body mode when the request is first analyzed.
    writer: BodyWriter,
    /// Decoder for the incoming response body.
    ///
    /// `None` until a non-100 response head has been parsed.
    reader: Option<BodyReader>,
    /// Passed to `AmendedRequest::analyze`; set by `into_send_body`.
    ///
    /// Presumably this bypasses the method/body compatibility check — verify in `analyze`.
    skip_method_body_check: bool,
}
impl BodyState {
    /// Whether a response body still needs to be read.
    ///
    /// Returns `false` only when the reader is known to carry no body: either
    /// `NoBody` or a zero-length delimited body. When no reader has been set,
    /// this reports `true`.
    fn need_response_body(&self) -> bool {
        match self.reader {
            Some(BodyReader::NoBody) | Some(BodyReader::LengthDelimited(0)) => false,
            _ => true,
        }
    }
}
/// Position of a call within the HTTP/1.1 exchange.
#[derive(Clone, Copy, PartialEq, Eq)]
enum Phase {
    /// Writing the request line.
    SendLine,
    /// Writing headers; the value is the number of header lines already written.
    SendHeaders(usize),
    /// Prelude complete; the request body (if any) is being written.
    SendBody,
    /// Waiting for the response head.
    RecvResponse,
    /// Reading the response body.
    RecvBody,
}
impl Default for Phase {
fn default() -> Self {
Self::SendLine
}
}
impl Phase {
    /// `true` while the request line or headers are still being written.
    fn is_prelude(&self) -> bool {
        match self {
            Phase::SendLine | Phase::SendHeaders(_) => true,
            Phase::SendBody | Phase::RecvResponse | Phase::RecvBody => false,
        }
    }

    /// `true` once the prelude is done and the request body phase has begun.
    fn is_body(&self) -> bool {
        *self == Phase::SendBody
    }
}
impl<B> Call<WithoutBody, B> {
    /// Re-types this body-less call as one that sends a body.
    ///
    /// Also sets `skip_method_body_check`, which is later passed to request analysis.
    ///
    /// # Panics
    /// Panics if the request has already been analyzed (e.g. after `write`).
    pub(crate) fn into_send_body(mut self) -> Call<WithBody, B> {
        assert!(!self.analyzed);
        self.state.skip_method_body_check = true;

        let Call {
            request,
            analyzed,
            state,
            ..
        } = self;

        Call {
            request,
            analyzed,
            state,
            _ph: PhantomData,
        }
    }

    /// Writes as much of the request line and headers as fits into `output`.
    ///
    /// Returns the number of bytes written. Returns `0` once the prelude is
    /// complete.
    ///
    /// # Errors
    /// - Request analysis errors (e.g. a method that requires a body).
    /// - `Error::OutputOverflow` when not even one prelude part fits.
    pub fn write(&mut self, output: &mut [u8]) -> Result<usize, Error> {
        self.analyze_request()?;

        let mut writer = Writer::new(output);
        try_write_prelude(&self.request, &mut self.state, &mut writer)?;

        Ok(writer.len())
    }

    /// `true` once the request line and all headers have been written.
    pub fn is_finished(&self) -> bool {
        let still_in_prelude = self.state.phase.is_prelude();
        !still_in_prelude
    }

    /// Proceeds to receiving the response; see `do_into_receive` for the error cases.
    pub fn into_receive(self) -> Result<Call<RecvResponse, B>, Error> {
        self.do_into_receive()
    }
}
impl<B> Call<WithBody, B> {
pub fn write(&mut self, input: &[u8], output: &mut [u8]) -> Result<(usize, usize), Error> {
self.analyze_request()?;
let mut w = Writer::new(output);
let mut input_used = 0;
if self.is_prelude() {
try_write_prelude(&self.request, &mut self.state, &mut w)?;
} else if self.is_body() {
if !input.is_empty() && self.state.writer.is_ended() {
return Err(Error::BodyContentAfterFinish);
}
if let Some(left) = self.state.writer.left_to_send() {
if input.len() as u64 > left {
return Err(Error::BodyLargerThanContentLength);
}
}
input_used = self.state.writer.write(input, &mut w);
}
let output_used = w.len();
Ok((input_used, output_used))
}
pub(crate) fn consume_direct_write(&mut self, amount: usize) -> Result<(), Error> {
if let Some(left) = self.state.writer.left_to_send() {
if amount as u64 > left {
return Err(Error::BodyLargerThanContentLength);
}
} else {
return Err(Error::BodyIsChunked);
}
self.state.writer.consume_direct_write(amount);
Ok(())
}
pub(crate) fn is_prelude(&self) -> bool {
self.state.phase.is_prelude()
}
pub(crate) fn is_body(&self) -> bool {
self.state.phase.is_body()
}
pub(crate) fn is_chunked(&self) -> bool {
self.state.writer.is_chunked()
}
pub fn is_finished(&self) -> bool {
self.state.writer.is_ended()
}
pub fn into_receive(self) -> Result<Call<RecvResponse, B>, Error> {
self.do_into_receive()
}
}
fn try_write_prelude<B>(
request: &AmendedRequest<B>,
state: &mut BodyState,
w: &mut Writer,
) -> Result<(), Error> {
let at_start = w.len();
loop {
if try_write_prelude_part(request, state, w) {
continue;
}
let written = w.len() - at_start;
if written > 0 || state.phase.is_body() {
return Ok(());
} else {
return Err(Error::OutputOverflow);
}
}
}
fn try_write_prelude_part<Body>(
request: &AmendedRequest<Body>,
state: &mut BodyState,
w: &mut Writer,
) -> bool {
match &mut state.phase {
Phase::SendLine => {
let success = do_write_send_line(request.prelude(), w);
if success {
state.phase = Phase::SendHeaders(0);
}
success
}
Phase::SendHeaders(index) => {
let header_count = request.headers_len();
let all = request.headers();
let skipped = all.skip(*index);
do_write_headers(skipped, index, header_count - 1, w);
if *index == header_count {
state.phase = Phase::SendBody;
}
false
}
_ => false,
}
}
/// Writes the request line `METHOD target VERSION\r\n`.
///
/// Returns whether it fit into `w`. `Version`'s `Debug` form is the wire
/// form, e.g. `HTTP/1.1`.
fn do_write_send_line(line: (&Method, &str, Version), w: &mut Writer) -> bool {
    let (method, target, version) = line;
    w.try_write(|out| write!(out, "{} {} {:?}\r\n", method, target, version))
}
/// Writes `name: value\r\n` lines until one no longer fits into `w`.
///
/// `index` is advanced once per header written. The header whose position
/// equals `last_index` is followed by the extra `\r\n` that ends the header section.
fn do_write_headers<'a, I>(headers: I, index: &mut usize, last_index: usize, w: &mut Writer)
where
    I: Iterator<Item = (&'a HeaderName, &'a HeaderValue)>,
{
    for (name, value) in headers {
        let is_last = *index == last_index;

        let fitted = w.try_write(|out| {
            write!(out, "{}: ", name)?;
            out.write_all(value.as_bytes())?;
            out.write_all(b"\r\n")?;
            if is_last {
                out.write_all(b"\r\n")?;
            }
            Ok(())
        });

        if !fitted {
            break;
        }
        *index += 1;
    }
}
impl<B> Call<RecvResponse, B> {
    /// Tries to parse a response head from `input`.
    ///
    /// Returns `Ok(None)` when `input` does not yet hold a usable head. A partially
    /// received redirect that already has a `location` header is accepted; a
    /// `connection: close` header is inserted, and all of `input` is reported as
    /// consumed.
    ///
    /// A `100 Continue` head is returned without setting up a body reader. Any
    /// other status configures the body reader for the response.
    ///
    /// # Errors
    /// - Parse errors.
    /// - `Error::HeadersWith100` for a `100 Continue` that carries headers.
    /// - Errors from `BodyReader::for_response`.
    pub fn try_response(&mut self, input: &[u8]) -> Result<Option<(usize, Response<()>)>, Error> {
        let (input_used, response) = match try_parse_response::<MAX_RESPONSE_HEADERS>(input)? {
            Some(complete) => complete,
            None => {
                let mut partial =
                    match try_parse_partial_response::<MAX_RESPONSE_HEADERS>(input)? {
                        Some(r) => r,
                        None => return Ok(None),
                    };

                let redirect_ready = partial.status().is_redirection()
                    && partial.headers().contains_key("location");
                if !redirect_ready {
                    return Ok(None);
                }

                debug!("Partial redirection response, insert fake connection: close");
                partial
                    .headers_mut()
                    .insert("connection", HeaderValue::from_static("close"));
                (input.len(), partial)
            }
        };

        log_data(&input[..input_used]);

        let status = response.status().as_u16();

        if status == StatusCode::CONTINUE {
            if response.headers().is_empty() {
                return Ok(Some((input_used, response)));
            }
            return Err(Error::HeadersWith100);
        }

        let is_http10 = response.version() == Version::HTTP_10;
        let lookup = |name: &str| response.headers().get(name).and_then(|v| v.to_str().ok());

        let reader = BodyReader::for_response(is_http10, self.request.method(), status, &lookup)?;
        self.state.reader = Some(reader);

        Ok(Some((input_used, response)))
    }

    /// `true` once a (non-100) response head has configured the body reader.
    pub fn is_finished(&self) -> bool {
        self.state.reader.is_some()
    }

    /// Moves on to reading the body.
    ///
    /// Returns `Ok(None)` when the response has no body.
    ///
    /// # Errors
    /// Returns `Error::IncompleteResponse` if no response head has been parsed yet.
    pub fn into_body(self) -> Result<Option<Call<RecvBody, B>>, Error> {
        let has_body = match &self.state.reader {
            None => return Err(Error::IncompleteResponse),
            Some(reader) => !matches!(reader, BodyReader::NoBody),
        };

        if has_body {
            Ok(Some(self.do_into_body()))
        } else {
            Ok(None)
        }
    }

    /// Whether a response body still needs to be read; see `BodyState::need_response_body`.
    pub(crate) fn need_response_body(&self) -> bool {
        self.state.need_response_body()
    }

    /// Unconditionally switches to the `RecvBody` state, keeping all progress.
    pub(crate) fn do_into_body(self) -> Call<RecvBody, B> {
        let Call {
            request,
            analyzed,
            mut state,
            ..
        } = self;

        state.phase = Phase::RecvBody;

        Call {
            request,
            analyzed,
            state,
            _ph: PhantomData,
        }
    }
}
impl<B> Call<RecvBody, B> {
    /// Decodes response body data from `input` into `output`.
    ///
    /// Returns whatever `BodyReader::read` reports — presumably `(input_used,
    /// output_used)`, mirroring `write`. Returns `(0, 0)` once the body has ended.
    ///
    /// # Panics
    /// Panics if no body reader is set.
    pub fn read(&mut self, input: &[u8], output: &mut [u8]) -> Result<(usize, usize), Error> {
        let reader = self.state.reader.as_mut().unwrap();

        if !reader.is_ended() {
            return reader.read(input, output);
        }

        Ok((0, 0))
    }

    /// Whether the body reader has reached the end of the body.
    ///
    /// # Panics
    /// Panics if no body reader is set.
    pub fn is_ended(&self) -> bool {
        self.state.reader.as_ref().unwrap().is_ended()
    }

    /// Whether the body is delimited by the connection closing.
    ///
    /// # Panics
    /// Panics if no body reader is set.
    pub fn is_close_delimited(&self) -> bool {
        let reader = self.state.reader.as_ref().unwrap();
        matches!(reader, BodyReader::CloseDelimited)
    }
}
impl<State, B> fmt::Debug for Call<State, B> {
    /// Shows only the current phase; the request and body codecs are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Call");
        out.field("phase", &self.state.phase);
        out.finish()
    }
}
impl fmt::Debug for Phase {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::SendLine => write!(f, "SendLine"),
Self::SendHeaders(_) => write!(f, "SendHeaders"),
Self::SendBody => write!(f, "SendBody"),
Self::RecvResponse => write!(f, "RecvResponse"),
Self::RecvBody => write!(f, "RecvBody"),
}
}
}
// Unit tests for `Call`. They cover prelude and body serialization on the
// sending side, and the errors raised by request analysis and by body-length checks.
#[cfg(test)]
mod test {
use super::*;
use std::str;
use http::{Method, Request};
use crate::Error;
// Compile-time check: calls in both sending states are `Send + Sync`.
#[test]
fn ensure_send_sync() {
fn is_send_sync<T: Send + Sync>(_t: T) {}
is_send_sync(Call::without_body(Request::new(())).unwrap());
is_send_sync(Call::with_body(Request::post("/").body(()).unwrap()).unwrap());
}
// Constructing a body-less call from a default builder request does not panic.
#[test]
fn create_empty() {
let req = Request::builder().body(()).unwrap();
let _call = Call::without_body(req);
}
// Constructing a body-sending call from a default builder request does not panic.
#[test]
fn create_streaming() {
let req = Request::builder().body(()).unwrap();
let _call = Call::with_body(req);
}
// A HEAD request writes the request line, a Host header derived from the URI,
// and the terminating blank line.
#[test]
fn head_simple() {
let req = Request::head("http://foo.test/page").body(()).unwrap();
let mut call = Call::without_body(req).unwrap();
let mut output = vec![0; 1024];
let n = call.write(&mut output).unwrap();
let s = str::from_utf8(&output[..n]).unwrap();
assert_eq!(s, "HEAD /page HTTP/1.1\r\nhost: foo.test\r\n\r\n");
}
// Sending a body with HEAD is rejected on the first write, during analysis.
#[test]
fn head_with_body() {
let req = Request::head("http://foo.test/page").body(()).unwrap();
let mut call = Call::with_body(req).unwrap();
let err = call.write(&[], &mut []).unwrap_err();
assert_eq!(err, Error::MethodForbidsBody(Method::HEAD));
}
// The first write emits only the prelude (no input consumed); the second writes the body.
#[test]
fn post_simple() {
let req = Request::post("http://f.test/page")
.header("content-length", 5)
.body(())
.unwrap();
let mut call = Call::with_body(req).unwrap();
let mut output = vec![0; 1024];
let (i1, n1) = call.write(b"hallo", &mut output).unwrap();
let (i2, n2) = call.write(b"hallo", &mut output[n1..]).unwrap();
assert_eq!(i1, 0);
assert_eq!(i2, 5);
assert_eq!(n1, 56);
assert_eq!(n2, 5);
let s = str::from_utf8(&output[..n1 + n2]).unwrap();
assert_eq!(
s,
"POST /page HTTP/1.1\r\nhost: f.test\r\ncontent-length: 5\r\n\r\nhallo"
);
}
// With small output slices, each write emits only the whole prelude lines that fit.
// The body follows once the prelude is complete.
#[test]
fn post_small_output() {
let req = Request::post("http://f.test/page")
.header("content-length", 5)
.body(())
.unwrap();
let mut call = Call::with_body(req).unwrap();
let mut output = vec![0; 1024];
let body = b"hallo";
{
let (i, n) = call.write(body, &mut output[..25]).unwrap();
assert_eq!(i, 0);
let s = str::from_utf8(&output[..n]).unwrap();
assert_eq!(s, "POST /page HTTP/1.1\r\n");
assert!(!call.is_finished());
}
{
let (i, n) = call.write(body, &mut output[..20]).unwrap();
assert_eq!(i, 0);
let s = str::from_utf8(&output[..n]).unwrap();
assert_eq!(s, "host: f.test\r\n");
assert!(!call.is_finished());
}
{
let (i, n) = call.write(body, &mut output[..21]).unwrap();
assert_eq!(i, 0);
let s = str::from_utf8(&output[..n]).unwrap();
assert_eq!(s, "content-length: 5\r\n\r\n");
assert!(!call.is_finished());
}
{
let (i, n) = call.write(body, &mut output[..25]).unwrap();
assert_eq!(n, 5);
assert_eq!(i, 5);
let s = str::from_utf8(&output[..n]).unwrap();
assert_eq!(s, "hallo");
assert!(call.is_finished());
}
}
// Body input longer than the declared content-length is rejected.
#[test]
fn post_with_short_content_length() {
let req = Request::post("http://f.test/page")
.header("content-length", 2)
.body(())
.unwrap();
let mut call = Call::with_body(req).unwrap();
let body = b"hallo";
let mut output = vec![0; 1024];
let r = call.write(body, &mut output);
assert!(r.is_ok());
let r = call.write(body, &mut output);
assert_eq!(r.unwrap_err(), Error::BodyLargerThanContentLength);
}
// A length-delimited body may arrive in pieces; the call finishes only once
// content-length bytes have been written.
#[test]
fn post_with_short_body_input() {
let req = Request::post("http://f.test/page")
.header("content-length", 5)
.body(())
.unwrap();
let mut call = Call::with_body(req).unwrap();
let mut output = vec![0; 1024];
let (i1, n1) = call.write(b"ha", &mut output).unwrap();
let (i2, n2) = call.write(b"ha", &mut output[n1..]).unwrap();
assert_eq!(i1, 0);
assert_eq!(i2, 2);
assert_eq!(n1, 56);
assert_eq!(n2, 2);
let s = str::from_utf8(&output[..n1 + n2]).unwrap();
assert_eq!(
s,
"POST /page HTTP/1.1\r\nhost: f.test\r\ncontent-length: 5\r\n\r\nha"
);
assert!(!call.is_finished());
let (i, n2) = call.write(b"llo", &mut output).unwrap();
assert_eq!(i, 3);
let s = str::from_utf8(&output[..n2]).unwrap();
assert_eq!(s, "llo");
assert!(call.is_finished());
}
// An explicit chunked transfer-encoding frames each write as a chunk.
// An empty write emits the final `0\r\n\r\n` chunk.
#[test]
fn post_with_chunked() {
let req = Request::post("http://f.test/page")
.header("transfer-encoding", "chunked")
.body(())
.unwrap();
let mut call = Call::with_body(req).unwrap();
let body = b"hallo";
let mut output = vec![0; 1024];
let (i1, n1) = call.write(body, &mut output).unwrap();
let (i2, n2) = call.write(body, &mut output[n1..]).unwrap();
let (_, n3) = call.write(&[], &mut output[n1 + n2..]).unwrap();
assert_eq!(i1, 0);
assert_eq!(i2, 5);
assert_eq!(n1, 65);
assert_eq!(n2, 10);
assert_eq!(n3, 5);
let s = str::from_utf8(&output[..n1 + n2 + n3]).unwrap();
assert_eq!(
s,
"POST /page HTTP/1.1\r\nhost: f.test\r\ntransfer-encoding: chunked\r\n\r\n5\r\nhallo\r\n0\r\n\r\n"
);
}
// Sending POST without a body is rejected on the first write, during analysis.
#[test]
fn post_without_body() {
let req = Request::post("http://foo.test/page").body(()).unwrap();
let mut call = Call::without_body(req).unwrap();
let err = call.write(&mut []).unwrap_err();
assert_eq!(err, Error::MethodRequiresBody(Method::POST));
}
// Without any framing header, a transfer-encoding: chunked header is added
// and the body is chunk-encoded.
#[test]
fn post_streaming() {
let req = Request::post("http://f.test/page").body(()).unwrap();
let mut call = Call::with_body(req).unwrap();
let mut output = vec![0; 1024];
let (i1, n1) = call.write(b"hallo", &mut output).unwrap();
let (i2, n2) = call.write(b"hallo", &mut output[n1..]).unwrap();
let (i3, n3) = call.write(&[], &mut output[n1 + n2..]).unwrap();
assert_eq!(i1, 0);
assert_eq!(i2, 5);
assert_eq!(n1, 65);
assert_eq!(n2, 10);
assert_eq!(i3, 0);
assert_eq!(n3, 5);
let s = str::from_utf8(&output[..(n1 + n2 + n3)]).unwrap();
assert_eq!(
s,
"POST /page HTTP/1.1\r\nhost: f.test\r\ntransfer-encoding: chunked\r\n\r\n5\r\nhallo\r\n0\r\n\r\n"
);
}
// A string-valued content-length header produces a length-delimited body.
#[test]
fn post_streaming_with_size() {
let req = Request::post("http://f.test/page")
.header("content-length", "5")
.body(())
.unwrap();
let mut call = Call::with_body(req).unwrap();
let mut output = vec![0; 1024];
let (i1, n1) = call.write(b"hallo", &mut output).unwrap();
let (i2, n2) = call.write(b"hallo", &mut output[n1..]).unwrap();
assert_eq!(i1, 0);
assert_eq!(n1, 56);
assert_eq!(i2, 5);
assert_eq!(n2, 5);
let s = str::from_utf8(&output[..(n1 + n2)]).unwrap();
assert_eq!(
s,
"POST /page HTTP/1.1\r\nhost: f.test\r\ncontent-length: 5\r\n\r\nhallo"
);
}
// Writing more data after a chunked body was ended (by an empty write) is rejected.
#[test]
fn post_streaming_after_end() {
let req = Request::post("http://f.test/page").body(()).unwrap();
let mut call = Call::with_body(req).unwrap();
let mut output = vec![0; 1024];
let (_, n1) = call.write(b"hallo", &mut output).unwrap();
let (_, n2) = call.write(&[], &mut output[n1..]).unwrap();
let err = call.write(b"after end", &mut output[(n1 + n2)..]);
assert_eq!(err, Err(Error::BodyContentAfterFinish));
}
// Once content-length bytes are written, further input reports BodyContentAfterFinish.
#[test]
fn post_streaming_too_much() {
let req = Request::post("http://f.test/page")
.header("content-length", "5")
.body(())
.unwrap();
let mut call = Call::with_body(req).unwrap();
let mut output = vec![0; 1024];
let (_, n1) = call.write(b"hallo", &mut output).unwrap();
let (_, n2) = call.write(b"hallo", &mut output[n1..]).unwrap();
let err = call.write(b"fail", &mut output[n1 + n2..]).unwrap_err();
assert_eq!(err, Error::BodyContentAfterFinish);
}
}