use std::fmt;
use std::io::Write;
use std::marker::PhantomData;
use amended::AmendedResponse;
use http::{Method, Response, StatusCode, Version};
use crate::body::{BodyReader, BodyWriter};
use crate::ext::{MethodExt, StatusCodeExt};
use crate::util::Writer;
use crate::{ArrayVec, CloseReason};
mod amended;
#[cfg(test)]
mod test;
/// Maximum number of request headers (per the name).
///
/// NOTE(review): presumably the header capacity used when parsing the request
/// head; the code that enforces it lives outside this file — confirm.
pub const MAX_REQUEST_HEADERS: usize = 128;
/// One server-side request/response exchange, typed by its current state.
///
/// `State` is one of the zero-sized marker types in the `state` module
/// (e.g. `RecvRequest`, `SendBody`). Transition methods consume a `Reply` in
/// one state and hand back a `Reply` in the next, so only the operations
/// valid for the current state can be called.
pub struct Reply<State> {
    // Data carried across every state transition (rewrapped via `wrap`).
    inner: Inner,
    // Type-level state marker; holds no data.
    _ph: PhantomData<State>,
}
/// State shared by every `Reply<State>`; `wrap` moves it into the next typed state.
#[derive(Debug)]
pub(crate) struct Inner {
    /// Progress through writing the response (status line, headers, body).
    pub phase: ResponsePhase,
    /// Request-body reader and response-body writer.
    pub state: BodyState,
    /// The response to send, once provided (set by `append_request`).
    pub response: Option<AmendedResponse>,
    /// Reasons the connection must be closed after this exchange (capacity 4);
    /// the first one is what `Reply<Cleanup>::close_reason` explains.
    pub close_reason: ArrayVec<CloseReason, 4>,
    /// Read a request body even where one would not normally be read
    /// (the tests use it for `CONNECT`). Presumably set by `force_recv_body()`.
    pub force_recv_body: bool,
    /// Allow a response body even where method/status would forbid it
    /// (the tests use it for `HEAD` and `CONNECT`). Presumably set by `force_send_body()`.
    pub force_send_body: bool,
    /// Method of the received request; `append_request` requires it to be set.
    pub method: Option<Method>,
    /// NOTE(review): presumably whether the request carried `Expect: 100-continue` — confirm.
    pub expect_100: bool,
    /// NOTE(review): presumably set when the server declines the 100-continue — confirm.
    pub expect_100_reject: bool,
}
/// Which part of the response is currently being written.
///
/// NOTE(review): the `usize` in `Headers` is presumably the index of the next
/// header to write; it is maintained outside this file — confirm.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) enum ResponsePhase {
    /// Writing the status line.
    Status,
    /// Writing the headers.
    Headers(usize),
    /// Writing the body.
    Body,
}

impl ResponsePhase {
    /// `true` while the response head (status line or headers) is being written.
    fn is_prelude(&self) -> bool {
        match *self {
            ResponsePhase::Status | ResponsePhase::Headers(_) => true,
            ResponsePhase::Body => false,
        }
    }

    /// `true` once the head is done and the body phase has been reached.
    fn is_body(&self) -> bool {
        *self == ResponsePhase::Body
    }
}
/// Body-related state of the exchange.
#[derive(Debug, Default)]
pub(crate) struct BodyState {
    /// Reader for the incoming request body, presumably set up when the
    /// request has one.
    reader: Option<BodyReader>,
    /// Writer for the outgoing response body; `append_request` sets it to
    /// chunked or none depending on the request method and response status.
    writer: Option<BodyWriter>,
    /// NOTE(review): presumably makes body reading pause at chunk boundaries
    /// of a chunked body; read outside this file — confirm.
    stop_on_chunk_boundary: bool,
}
/// Zero-sized marker types, one per state of the `Reply` state machine.
#[doc(hidden)]
pub mod state {
    /// Gives a state marker a printable name (used by `Reply`'s `Debug` output).
    pub(crate) trait Named {
        /// The state's name, identical to the marker type's identifier.
        fn name() -> &'static str;
    }

    /// Declares marker type `$n` and implements [`Named`] for it.
    ///
    /// `$n` must be an identifier; taking `ident` rather than `tt` rejects any
    /// other token where the macro is called instead of failing inside the
    /// expansion. The private `()` field prevents the marker from being
    /// constructed outside this module.
    macro_rules! reply_state {
        ($n:ident) => {
            #[doc(hidden)]
            pub struct $n(());

            impl Named for $n {
                fn name() -> &'static str {
                    stringify!($n)
                }
            }
        };
    }

    reply_state!(RecvRequest);
    reply_state!(Send100);
    reply_state!(RecvBody);
    reply_state!(ProvideResponse);
    reply_state!(SendResponse);
    reply_state!(SendBody);
    reply_state!(Cleanup);
}
use self::state::*;
impl<S> Reply<S> {
    /// Puts `inner` into a `Reply` typed with state `S`.
    ///
    /// The new reply is logged at debug level through its `Debug` impl,
    /// which prints the state name (e.g. `Reply<RecvBody>`).
    fn wrap(inner: Inner) -> Self
    where
        S: Named,
    {
        let next = Self {
            inner,
            _ph: PhantomData,
        };
        debug!("{:?}", next);
        next
    }

    /// Test-only access to the shared state.
    #[cfg(test)]
    pub(crate) fn inner(&self) -> &Inner {
        &self.inner
    }
}
mod recvreq;
/// The possible next states after a request head has been received,
/// as returned by `proceed()`.
pub enum RecvRequestResult {
    /// The request sent `Expect: 100-continue`; the client waits for the
    /// interim response (see `accept`) before sending its body.
    Send100(Reply<Send100>),
    /// A request body is to be read (e.g. `content-length` or chunked).
    RecvBody(Reply<RecvBody>),
    /// No request body will be read; a response can be provided directly.
    ProvideResponse(Reply<ProvideResponse>),
}
mod send100;
/// Attaches the provided `response` to `inner` and picks the default body writer.
///
/// The response body defaults to chunked when both the request method and the
/// response status allow a body, and to no body otherwise. The body reader and
/// every other field of `inner` are carried over unchanged; a previously stored
/// response, if any, is replaced.
///
/// # Panics
///
/// Panics if `inner.method` is `None`, i.e. if no request has been recorded yet.
fn append_request(inner: Inner, response: Response<()>) -> Inner {
    // NOTE(review): this reuses `allow_request_body`, a predicate named for
    // request bodies, to decide whether the *response* may carry a body.
    let method_allows_body = inner
        .method
        .as_ref()
        .expect("request method is recorded before a response is appended")
        .allow_request_body();
    let status_allows_body = response.status().body_allowed();

    let default_body_mode = if method_allows_body && status_allows_body {
        BodyWriter::new_chunked()
    } else {
        BodyWriter::new_none()
    };

    // Struct-update syntax carries over every untouched field, so none can be
    // forgotten or mis-copied.
    Inner {
        state: BodyState {
            writer: Some(default_body_mode),
            ..inner.state
        },
        response: Some(AmendedResponse::new(response)),
        ..inner
    }
}
/// Writes the status line `<version> <code> <reason>\r\n` into `w`.
///
/// `line` is the `(version, status)` pair. The reason phrase is the status
/// code's canonical reason, or `Unknown` if it has none. When `end_head` is
/// true, an extra `\r\n` follows, ending the response head right after the
/// status line.
///
/// Returns whatever `Writer::try_write` reports. NOTE(review): presumably
/// `false` when the output buffer cannot hold the line — confirm.
fn do_write_send_line(line: (Version, StatusCode), w: &mut Writer, end_head: bool) -> bool {
    let (version, status) = line;
    let reason = status.canonical_reason().unwrap_or("Unknown");
    let head_end = if end_head { "\r\n" } else { "" };

    w.try_write(|out| {
        write!(
            out,
            "{:?} {} {}\r\n{}",
            version,
            status.as_str(),
            reason,
            head_end
        )
    })
}
mod provres;
mod recvbody;
mod sendres;
/// The possible next states after the response head has been written,
/// as returned by `proceed()`.
pub enum SendResponseResult {
    /// The response has a body to write.
    SendBody(Reply<SendBody>),
    /// The response has no body; the exchange goes straight to cleanup.
    Cleanup(Reply<Cleanup>),
}
mod sendbody;
impl Reply<Cleanup> {
    /// Whether the connection has to be closed after this exchange, i.e.
    /// whether at least one close reason was recorded.
    pub fn must_close_connection(&self) -> bool {
        self.inner.close_reason.first().is_some()
    }

    /// Human-readable explanation of the first recorded close reason, if any.
    pub fn close_reason(&self) -> Option<&'static str> {
        let first = self.inner.close_reason.first()?;
        Some(first.explain())
    }
}
impl<State: Named> fmt::Debug for Reply<State> {
    /// Formats as `Reply<StateName>`, e.g. `Reply<SendBody>`; the shared
    /// inner state is not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Reply<")?;
        f.write_str(State::name())?;
        f.write_str(">")
    }
}
impl fmt::Debug for ResponsePhase {
    /// Formats the phase as the step being performed; the value carried by
    /// `Headers` is not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            ResponsePhase::Status => "SendStatus",
            ResponsePhase::Headers(_) => "SendHeaders",
            ResponsePhase::Body => "SendBody",
        };
        f.write_str(label)
    }
}
#[cfg(test)]
mod tests {
    // End-to-end tests of the server `Reply` state machine: raw request bytes
    // go in, the typed states are walked, and both consumed byte counts and
    // the bytes written back out are checked.

    use super::*;
    use http::{Response, StatusCode};
    use std::str;

    /// A GET without a body goes straight to `ProvideResponse`; a response
    /// with no length header is sent with `transfer-encoding: chunked`.
    #[test]
    fn get_simple() {
        let mut reply = Reply::new().unwrap();
        let input = b"GET /page HTTP/1.1\r\n\
            host: test.local\r\n\
            \r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 40);
        assert_eq!(request.method(), "GET");
        assert_eq!(request.uri().path(), "/page");

        let reply = reply.proceed().unwrap();
        let RecvRequestResult::ProvideResponse(reply) = reply else {
            panic!("Expected ProvideResponse state");
        };

        let response = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "text/plain")
            .body(())
            .unwrap();
        let mut reply = reply.provide(response).unwrap();

        let mut output = vec![0_u8; 1024];
        let n = reply.write(&mut output).unwrap();
        let s = str::from_utf8(&output[..n]).unwrap();
        assert_eq!(
            s,
            "HTTP/1.1 200 OK\r\n\
            transfer-encoding: chunked\r\n\
            content-type: text/plain\r\n\
            \r\n"
        );
    }

    /// `Expect: 100-continue` routes through `Send100`; accepting writes the
    /// interim response, after which the chunked body is decoded.
    #[test]
    fn post_with_100_continue() {
        let mut reply = Reply::new().unwrap();
        let input = b"POST /upload HTTP/1.1\r\n\
            host: test.local\r\n\
            expect: 100-continue\r\n\
            transfer-encoding: chunked\r\n\
            \r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 93);
        assert_eq!(request.method(), "POST");
        assert_eq!(request.uri().path(), "/upload");
        assert_eq!(request.headers().get("expect").unwrap(), "100-continue");

        let reply = reply.proceed().unwrap();
        let reply = match reply {
            RecvRequestResult::Send100(r) => r,
            _ => panic!("Expected Send100 state"),
        };

        let mut output = vec![0_u8; 1024];
        let (n, mut reply) = reply.accept(&mut output).unwrap();
        assert_eq!(&output[..n], b"HTTP/1.1 100 Continue\r\n\r\n");

        let mut body_buf = vec![0_u8; 1024];
        let input = b"5\r\nhello\r\n";
        let (input_used, output_used) = reply.read(input, &mut body_buf).unwrap();
        assert_eq!(input_used, 10);
        assert_eq!(&body_buf[..output_used], b"hello");

        let input = b"0\r\n\r\n";
        let (input_used, output_used) = reply.read(input, &mut body_buf[5..]).unwrap();
        assert_eq!(input_used, 5);
        assert_eq!(output_used, 0);
        assert!(reply.is_ended());
    }

    /// A `content-length` body read in one go ends exactly at the declared length.
    #[test]
    fn post_with_content_length() {
        let mut reply = Reply::new().unwrap();
        let input = b"POST /data HTTP/1.1\r\n\
            host: test.local\r\n\
            content-length: 11\r\n\
            \r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 61);
        assert_eq!(request.method(), "POST");
        assert_eq!(request.uri().path(), "/data");

        let reply = reply.proceed().unwrap();
        let mut reply = match reply {
            RecvRequestResult::RecvBody(r) => r,
            _ => panic!("Expected RecvBody state"),
        };

        let mut body_buf = vec![0_u8; 1024];
        let input = b"Hello World";
        let (input_used, output_used) = reply.read(input, &mut body_buf).unwrap();
        assert_eq!(input_used, 11);
        assert_eq!(&body_buf[..output_used], b"Hello World");
        assert!(reply.is_ended());
    }

    /// `provide` rejects a HEAD response that declares a body.
    #[test]
    fn head_response_with_body_fails() {
        let mut reply = Reply::new().unwrap();
        let input = b"HEAD /status HTTP/1.1\r\n\
            host: test.local\r\n\
            \r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 43);
        assert_eq!(request.method(), "HEAD");
        assert_eq!(request.uri().path(), "/status");

        let reply = reply.proceed().unwrap();
        let RecvRequestResult::ProvideResponse(reply) = reply else {
            panic!("Expected ProvideResponse state");
        };

        let response = Response::builder()
            .status(StatusCode::OK)
            .header("content-length", "1000")
            .body(())
            .unwrap();
        reply
            .provide(response)
            .expect_err("no body allowed on HEAD response");
    }

    /// With `force_send_body`, the HEAD response keeps its declared
    /// content-length and no chunked encoding is added.
    #[test]
    fn head_response_with_body_and_footgun() {
        let mut reply = Reply::new().unwrap();
        let input = b"HEAD /status HTTP/1.1\r\n\
            host: test.local\r\n\
            \r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 43);
        assert_eq!(request.method(), "HEAD");
        assert_eq!(request.uri().path(), "/status");

        let reply = reply.proceed().unwrap();
        let RecvRequestResult::ProvideResponse(mut reply) = reply else {
            panic!("Expected ProvideResponse state");
        };

        let response = Response::builder()
            .status(StatusCode::OK)
            .header("content-length", "1000")
            .body(())
            .unwrap();
        reply.force_send_body();
        let mut reply = reply.provide(response).unwrap();

        let mut output = vec![0_u8; 1024];
        let n = reply.write(&mut output).unwrap();
        let s = str::from_utf8(&output[..n]).unwrap();
        assert!(s.contains("content-length: 1000"));
        assert!(!s.contains("transfer-encoding"));
    }

    /// A chunked body without 100-continue goes straight to `RecvBody`.
    #[test]
    fn post_streaming() {
        let mut reply = Reply::new().unwrap();
        let input = b"POST /upload HTTP/1.1\r\n\
            host: test.local\r\n\
            transfer-encoding: chunked\r\n\
            \r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 71);
        assert_eq!(request.method(), "POST");
        assert_eq!(request.uri().path(), "/upload");

        let reply = reply.proceed().unwrap();
        let mut reply = match reply {
            RecvRequestResult::RecvBody(r) => r,
            _ => panic!("Expected RecvBody state"),
        };

        let mut body_buf = vec![0_u8; 1024];
        let input = b"5\r\nhello\r\n";
        let (input_used, output_used) = reply.read(input, &mut body_buf).unwrap();
        assert_eq!(input_used, 10);
        assert_eq!(output_used, 5);
        assert_eq!(&body_buf[..output_used], b"hello");

        let input = b"0\r\n\r\n";
        let (input_used, output_used) = reply.read(input, &mut body_buf[5..]).unwrap();
        assert_eq!(input_used, 5);
        assert_eq!(output_used, 0);
        assert!(reply.is_ended());
    }

    /// An incomplete request head consumes nothing and yields no request
    /// until the terminating blank line has arrived.
    #[test]
    fn post_small_input() {
        let mut reply = Reply::new().unwrap();

        let input1 = b"POST /upload";
        let (used1, req1) = reply.try_request(input1).unwrap();
        assert_eq!(used1, 0);
        assert!(req1.is_none());

        let input2 = b"POST /upload HTTP/1.1\r\n";
        let (used2, req2) = reply.try_request(input2).unwrap();
        assert_eq!(used2, 0);
        assert!(req2.is_none());

        let input3 = b"POST /upload HTTP/1.1\r\n\
            host: test.local\r\n";
        let (used3, req3) = reply.try_request(input3).unwrap();
        assert_eq!(used3, 0);
        assert!(req3.is_none());

        let input4 = b"POST /upload HTTP/1.1\r\n\
            host: test.local\r\n\
            \r\n";
        let (used4, req4) = reply.try_request(input4).unwrap();
        assert_eq!(used4, 43);
        let request = req4.unwrap();
        assert_eq!(request.method(), "POST");
        assert_eq!(request.uri().path(), "/upload");
    }

    /// Only `content-length` bytes are consumed even when more input is offered.
    #[test]
    fn post_with_short_content_length() {
        let mut reply = Reply::new().unwrap();
        let input = b"POST /upload HTTP/1.1\r\n\
            host: test.local\r\n\
            content-length: 2\r\n\
            \r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 62);
        assert_eq!(request.method(), "POST");

        let reply = reply.proceed().unwrap();
        let mut reply = match reply {
            RecvRequestResult::RecvBody(r) => r,
            _ => panic!("Expected RecvBody state"),
        };

        let mut body_buf = vec![0_u8; 1024];
        let input = b"hello";
        let (i1, o1) = reply.read(input, &mut body_buf).unwrap();
        assert_eq!(i1, 2);
        assert_eq!(o1, 2);
        assert!(reply.is_ended());
    }

    /// Input beyond the declared content-length is left unconsumed, and the
    /// body is complete once the declared length has been read.
    #[test]
    fn post_streaming_too_much() {
        let mut reply = Reply::new().unwrap();
        let input = b"POST /upload HTTP/1.1\r\n\
            host: test.local\r\n\
            content-length: 5\r\n\
            \r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 62);
        assert_eq!(request.method(), "POST");

        let reply = reply.proceed().unwrap();
        let mut reply = match reply {
            RecvRequestResult::RecvBody(r) => r,
            _ => panic!("Expected RecvBody state"),
        };

        let mut body_buf = vec![0_u8; 1024];
        let input = b"hello world";
        let (input_used, output_used) = reply.read(input, &mut body_buf).unwrap();
        assert_eq!(input_used, 5);
        assert_eq!(output_used, 5);
        assert_eq!(&body_buf[..output_used], b"hello");
        assert!(reply.is_ended());
    }

    /// Once the chunked body has ended, further reads consume and produce nothing.
    #[test]
    fn post_streaming_after_end() {
        let mut reply = Reply::new().unwrap();
        let input = b"POST /upload HTTP/1.1\r\n\
            host: test.local\r\n\
            transfer-encoding: chunked\r\n\
            \r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 71);
        assert_eq!(request.method(), "POST");

        let reply = reply.proceed().unwrap();
        let mut reply = match reply {
            RecvRequestResult::RecvBody(r) => r,
            _ => panic!("Expected RecvBody state"),
        };

        let mut body_buf = vec![0_u8; 1024];
        let input = b"5\r\nhello\r\n";
        let (input_used, output_used) = reply.read(input, &mut body_buf).unwrap();
        assert_eq!(input_used, 10);
        assert_eq!(output_used, 5);

        let input = b"0\r\n\r\n";
        let (input_used, output_used) = reply.read(input, &mut body_buf[5..]).unwrap();
        assert_eq!(input_used, 5);
        assert_eq!(output_used, 0);
        assert!(reply.is_ended());

        let input = b"more data";
        let (i1, o1) = reply.read(input, &mut body_buf).unwrap();
        assert_eq!(i1, 0);
        assert_eq!(o1, 0);
    }

    /// A content-length body may arrive across several partial reads.
    #[test]
    fn post_with_short_body_input() {
        let mut reply = Reply::new().unwrap();
        let input = b"POST /upload HTTP/1.1\r\n\
            host: test.local\r\n\
            content-length: 11\r\n\
            \r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 63);
        assert_eq!(request.method(), "POST");

        let reply = reply.proceed().unwrap();
        let mut reply = match reply {
            RecvRequestResult::RecvBody(r) => r,
            _ => panic!("Expected RecvBody state"),
        };

        let mut body_buf = vec![0_u8; 1024];
        let input = b"He";
        let (input_used, output_used) = reply.read(input, &mut body_buf).unwrap();
        assert_eq!(input_used, 2);
        assert_eq!(output_used, 2);
        assert_eq!(&body_buf[..output_used], b"He");

        let input = b"llo ";
        let (input_used, output_used) = reply.read(input, &mut body_buf[2..]).unwrap();
        assert_eq!(input_used, 4);
        assert_eq!(output_used, 4);
        assert_eq!(&body_buf[..6], b"Hello ");

        let input = b"World";
        let (input_used, output_used) = reply.read(input, &mut body_buf[6..]).unwrap();
        assert_eq!(input_used, 5);
        assert_eq!(output_used, 5);
        assert_eq!(&body_buf[..11], b"Hello World");
        assert!(reply.is_ended());
    }

    /// Methods outside the standard set are accepted by the request parser.
    #[test]
    fn non_standard_method_is_ok() {
        let mut reply = Reply::new().unwrap();
        let input = b"FNORD /page HTTP/1.1\r\n\
            host: test.local\r\n\
            \r\n";
        let result = reply.try_request(input);
        assert!(result.is_ok());
    }

    /// Guards against accidental growth of the types that are moved by value
    /// on every state transition.
    #[test]
    fn ensure_reasonable_stack_sizes() {
        macro_rules! ensure {
            ($type:ty, $size:tt) => {
                let sz = std::mem::size_of::<$type>();
                assert!(
                    sz <= $size,
                    "Stack size of {} is too big {} > {}",
                    stringify!($type),
                    sz,
                    $size
                );
            };
        }

        ensure!(http::Response<()>, 300);
        ensure!(AmendedResponse, 400);
        ensure!(Inner, 600);
        ensure!(Reply<RecvRequest>, 600);
    }

    /// CONNECT: no request body is read, and a bodiless 200 goes straight to cleanup.
    #[test]
    fn connect() {
        let mut reply = Reply::new().unwrap();
        let input = b"CONNECT example.com HTTP/1.1\r\nhost: example.com\r\n\r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 51);
        assert_eq!(request.method(), "CONNECT");
        assert_eq!(request.uri().path(), "");

        let RecvRequestResult::ProvideResponse(reply) = reply.proceed().unwrap() else {
            panic!("Expected ProvideResponse state");
        };

        let response = Response::builder().status(StatusCode::OK).body(()).unwrap();
        let mut reply = reply.provide(response).unwrap();

        let mut output = vec![0_u8; 1024];
        let n = reply.write(&mut output).unwrap();
        let s = str::from_utf8(&output[..n]).unwrap();
        assert_eq!(s, "HTTP/1.1 200 OK\r\n\r\n");

        let SendResponseResult::Cleanup(_reply) = reply.proceed() else {
            panic!("Expected Cleanup state")
        };
    }

    /// `force_recv_body` makes a CONNECT request's declared body readable.
    #[test]
    fn connect_read_body() {
        let mut reply = Reply::new().unwrap();
        reply.force_recv_body();
        let input =
            b"CONNECT example.com HTTP/1.1\r\nhost: example.com\r\ncontent-length: 1024\r\n\r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 73);
        assert_eq!(request.method(), "CONNECT");
        assert_eq!(request.uri().path(), "");

        let RecvRequestResult::RecvBody(_reply) = reply.proceed().unwrap() else {
            panic!("Expected RecvBody state");
        };
    }

    /// Without forcing, a CONNECT 200 response may not declare a body.
    #[test]
    fn connect_send_body_fails() {
        let mut reply = Reply::new().unwrap();
        let input =
            b"CONNECT example.com HTTP/1.1\r\nhost: example.com\r\ncontent-length: 1024\r\n\r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 73);
        assert_eq!(request.method(), "CONNECT");
        assert_eq!(request.uri().path(), "");

        let RecvRequestResult::ProvideResponse(reply) = reply.proceed().unwrap() else {
            panic!("Expected ProvideResponse state");
        };

        let response = Response::builder()
            .status(StatusCode::OK)
            .header("content-length", 1024)
            .body(())
            .unwrap();
        reply
            .provide(response)
            .expect_err("no body allowed on CONNECT response");
    }

    /// With `force_send_body`, a CONNECT 200 response keeps its content-length
    /// and proceeds to `SendBody`.
    #[test]
    fn connect_send_body_with_footgun() {
        let mut reply = Reply::new().unwrap();
        let input =
            b"CONNECT example.com HTTP/1.1\r\nhost: example.com\r\ncontent-length: 1024\r\n\r\n";
        let (input_used, request) = reply.try_request(input).unwrap();
        let request = request.unwrap();
        assert_eq!(input_used, 73);
        assert_eq!(request.method(), "CONNECT");
        assert_eq!(request.uri().path(), "");

        let RecvRequestResult::ProvideResponse(mut reply) = reply.proceed().unwrap() else {
            panic!("Expected ProvideResponse state");
        };

        let response = Response::builder()
            .status(StatusCode::OK)
            .header("content-length", 1024)
            .body(())
            .unwrap();
        reply.force_send_body();
        let mut reply = reply.provide(response).unwrap();

        let mut output = vec![0_u8; 1024];
        let n = reply.write(&mut output).unwrap();
        let s = str::from_utf8(&output[..n]).unwrap();
        assert_eq!(s, "HTTP/1.1 200 OK\r\ncontent-length: 1024\r\n\r\n");

        let SendResponseResult::SendBody(_reply) = reply.proceed() else {
            panic!("Expected SendBody state")
        };
    }
}