use bytes::{BufMut, Bytes, BytesMut};
use http::{header, header::AsHeaderName, HeaderValue, StatusCode, Version};
use log::{debug, trace};
use pingora_error::{Error, ErrorType::*, OrErr, Result, RetryType};
use pingora_http::{HMap, IntoCaseHeaderName, RequestHeader, ResponseHeader};
use pingora_timeout::timeout;
use std::io::ErrorKind;
use std::str;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::body::{BodyReader, BodyWriter};
use super::common::*;
use crate::protocols::http::HttpTask;
use crate::protocols::{Digest, SocketAddr, Stream, UniqueID, UniqueIDType};
use crate::utils::{BufRef, KVRef};
/// Client-side HTTP/1.x session: writes a request to an upstream connection and reads
/// the response (header and body) back from it.
pub struct HttpSession {
/// Bytes read while parsing the most recent response header. `raw_header` and
/// `preread_body` are offsets into this buffer.
buf: Bytes,
/// The connection this session reads from and writes to.
pub(crate) underlying_stream: Stream,
/// Range of `buf` holding the raw response header; set once a header is fully parsed.
raw_header: Option<BufRef>,
/// Range of `buf` holding bytes read past the end of the response header (body bytes or
/// the start of a following response after an informational one).
preread_body: Option<BufRef>,
/// Decodes the response body; framing is chosen lazily by `init_body_reader`.
body_reader: BodyReader,
/// Encodes the request body; framing is chosen from the request headers when they are written.
body_writer: BodyWriter,
/// Optional bound on each response-header read and each body read.
pub read_timeout: Option<Duration>,
/// Optional bound on writing the request header and each body write (flushes are not bounded).
pub write_timeout: Option<Duration>,
/// Whether / how long the connection may be kept alive; decided by `respect_keepalive`.
keepalive_timeout: KeepaliveStatus,
/// Connection metadata (TLS, timing, proxy, socket) captured when the session was created.
pub(crate) digest: Box<Digest>,
/// The most recently parsed response header (may be an informational 1xx response).
response_header: Option<Box<ResponseHeader>>,
/// The request header written on this session, kept for HEAD / upgrade / keepalive decisions.
request_written: Option<Box<RequestHeader>>,
/// Number of request body bytes reported written by the body writer.
bytes_sent: usize,
/// Number of response body bytes handed back to callers of `read_body_ref`.
body_recv: usize,
/// Set when an upgrade request received an upgrade response.
upgraded: bool,
/// Set by `maybe_upgrade_body_writer` once the upgraded request body is being sent.
received_upgrade_req_body: bool,
/// Set when the response body is delimited by connection close; disables keepalive.
close_delimited_resp: bool,
/// When true, skip rejecting responses whose `Content-Length` fails to parse.
allow_h1_response_invalid_content_length: bool,
}
impl HttpSession {
pub fn new(stream: Stream) -> Self {
let digest = Box::new(Digest {
ssl_digest: stream.get_ssl_digest(),
timing_digest: stream.get_timing_digest(),
proxy_digest: stream.get_proxy_digest(),
socket_digest: stream.get_socket_digest(),
});
HttpSession {
underlying_stream: stream,
buf: Bytes::new(), raw_header: None,
preread_body: None,
body_reader: BodyReader::new(true),
body_writer: BodyWriter::new(),
keepalive_timeout: KeepaliveStatus::Off,
response_header: None,
request_written: None,
read_timeout: None,
write_timeout: None,
digest,
bytes_sent: 0,
body_recv: 0,
upgraded: false,
received_upgrade_req_body: false,
close_delimited_resp: false,
allow_h1_response_invalid_content_length: false,
}
}
/// Creates a session like [`HttpSession::new`] and applies the peer's options.
///
/// Currently only `allow_h1_response_invalid_content_length` is taken from the peer's
/// options; a peer without options leaves the defaults untouched.
pub fn new_with_options<P: crate::upstreams::peer::Peer>(stream: Stream, peer: &P) -> Self {
    let mut session = Self::new(stream);
    let configured = peer
        .get_peer_options()
        .map(|options| options.allow_h1_response_invalid_content_length);
    if let Some(allow) = configured {
        session.set_allow_h1_response_invalid_content_length(allow);
    }
    session
}
pub async fn write_request_header(&mut self, req: Box<RequestHeader>) -> Result<usize> {
self.init_req_body_writer(&req);
let to_wire = http_req_header_to_wire(&req).unwrap();
trace!("Writing request header: {to_wire:?}");
let write_fut = self.underlying_stream.write_all(to_wire.as_ref());
match self.write_timeout {
Some(t) => match timeout(t, write_fut).await {
Ok(res) => res,
Err(_) => Err(std::io::Error::from(ErrorKind::TimedOut)),
},
None => write_fut.await,
}
.map_err(|e| match e.kind() {
ErrorKind::TimedOut => {
Error::because(WriteTimedout, "while writing request headers (timeout)", e)
}
_ => Error::because(WriteError, "while writing request headers", e),
})?;
self.underlying_stream
.flush()
.await
.or_err(WriteError, "flushing request header")?;
self.request_written = Some(req);
Ok(to_wire.len())
}
/// Writes one chunk of request body through the body writer without any timeout.
///
/// When the writer reports a byte count, it is added to `bytes_sent`. The writer's result
/// is returned unchanged.
async fn do_write_body(&mut self, buf: &[u8]) -> Result<Option<usize>> {
    let outcome = self
        .body_writer
        .write_body(&mut self.underlying_stream, buf)
        .await;
    if let Ok(Some(count)) = &outcome {
        self.bytes_sent += count;
    }
    outcome
}
pub async fn write_body(&mut self, buf: &[u8]) -> Result<Option<usize>> {
match self.write_timeout {
Some(t) => match timeout(t, self.do_write_body(buf)).await {
Ok(res) => res,
Err(_) => Error::e_explain(WriteTimedout, format!("writing body, timeout: {t:?}")),
},
None => self.do_write_body(buf).await,
}
}
/// On an upgraded session whose upgraded request body has been sent, marks an unfinished
/// response body reader as a zero-length body.
///
/// Called after the request body is finished (see `finish_body`).
fn maybe_force_close_body_reader(&mut self) {
    if !self.upgraded || !self.received_upgrade_req_body {
        return;
    }
    if self.body_reader.body_done() {
        return;
    }
    self.body_reader.init_content_length(0, b"");
}
/// Finishes the request body, flushes the connection and returns the writer's result.
///
/// Neither step is bounded by `write_timeout`.
///
/// # Errors
/// Propagates body writer errors; flush failures become `WriteError`.
pub async fn finish_body(&mut self) -> Result<Option<usize>> {
    let finished = self.body_writer.finish(&mut self.underlying_stream).await?;
    let flushed = self.underlying_stream.flush().await;
    flushed.or_err(WriteError, "flushing body")?;
    self.maybe_force_close_body_reader();
    Ok(finished)
}
/// Validates the parsed response header's `Content-Length` handling.
///
/// Duplicate `Content-Length` headers are always checked. The value's parseability is
/// checked unless `allow_h1_response_invalid_content_length` is set.
///
/// # Panics
/// If no response header has been stored yet.
fn validate_response(&self) -> Result<()> {
    let headers = &self
        .response_header
        .as_deref()
        .expect("response header must be read")
        .headers;
    super::common::check_dup_content_length(headers)?;
    if self.allow_h1_response_invalid_content_length {
        return Ok(());
    }
    self.get_content_length().map(|_| ())
}
/// Reads and parses the next response header from the connection.
///
/// Bytes left over from a previous call (e.g. a final response that arrived in the same
/// read as a `100 Continue`) are parsed before reading more from the stream. Each socket
/// read is bounded by `read_timeout` when set. On success the header is stored and
/// validated, the upgrade flag is updated, and the body reader is initialized. Returns
/// the header's length in bytes.
///
/// # Errors
/// - `InvalidHTTPHeader` if the header grows beyond `MAX_HEADER_SIZE`, fails to parse, or
///   contains an unusable header; `validate_response` errors are propagated as-is.
/// - `ReadTimedout` if a read exceeds `read_timeout`.
/// - `ConnectionClosed` (retry: `ReusedOnly`) if the peer closes before a full header arrives.
/// - `ReadError` for other I/O failures (retry: `ReusedOnly` only for OS-level errors).
pub async fn read_response(&mut self) -> Result<usize> {
    // If a previous call left bytes in `preread_body`, keep `self.buf` so they can be
    // parsed first; otherwise nothing in it is needed any more.
    if self.preread_body.as_ref().is_none_or(|b| b.is_empty()) {
        self.buf.clear();
    }
    let mut buf = BytesMut::with_capacity(INIT_HEADER_BUF_SIZE);
    let mut already_read: usize = 0;
    loop {
        // Checked before each read: the last read may push `already_read` past the limit,
        // but the buffer cannot keep growing across further reads.
        if already_read > MAX_HEADER_SIZE {
            return Error::e_explain(
                InvalidHTTPHeader,
                format!("Response header larger than {MAX_HEADER_SIZE}"),
            );
        }
        // Consume bytes buffered from the previous header before touching the socket.
        let preread = self.preread_body.take();
        let read_result = if let Some(preread) = preread.filter(|b| !b.is_empty()) {
            buf.put_slice(preread.get(&self.buf));
            Ok(preread.len())
        } else {
            let read_fut = self.underlying_stream.read_buf(&mut buf);
            match self.read_timeout {
                Some(t) => timeout(t, read_fut).await.map_err(|_| {
                    Error::explain(ReadTimedout, "while reading response headers")
                })?,
                None => read_fut.await,
            }
        };
        let n = match read_result {
            Ok(0) => {
                let mut e = Error::explain(
                    ConnectionClosed,
                    format!(
                        "while reading response headers, bytes already read: {already_read}",
                    ),
                );
                e.retry = RetryType::ReusedOnly;
                return Err(e);
            }
            Ok(n) => n,
            Err(e) => {
                // Only genuine OS errors are marked retryable on reused connections.
                let true_io_error = e.raw_os_error().is_some();
                let mut e = Error::because(
                    ReadError,
                    format!(
                        "while reading response headers, bytes already read: {already_read}",
                    ),
                    e,
                );
                if true_io_error {
                    e.retry = RetryType::ReusedOnly;
                }
                return Err(e);
            }
        };
        already_read += n;
        let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
        let mut resp = httparse::Response::new(&mut headers);
        let parsed = parse_resp_buffer(&mut resp, &buf);
        match parsed {
            HeaderParseState::Complete(s) => {
                self.raw_header = Some(BufRef(0, s));
                // Everything after the header is kept for the body reader (or next header).
                self.preread_body = Some(BufRef(s, already_read));
                // Header name/value locations are recorded relative to the buffer start so
                // they stay valid after `buf` is frozen below.
                let base = buf.as_ptr() as usize;
                let mut header_refs = Vec::<KVRef>::with_capacity(resp.headers.len());
                let _num_headers = populate_headers(base, &mut header_refs, resp.headers);
                let mut response_header = Box::new(ResponseHeader::build(
                    resp.code.unwrap(),
                    Some(resp.headers.len()),
                )?);
                response_header.set_version(match resp.version {
                    Some(1) => Version::HTTP_11,
                    Some(0) => Version::HTTP_10,
                    _ => Version::HTTP_09,
                });
                response_header.set_reason_phrase(resp.reason)?;
                let buf = buf.freeze();
                for header in header_refs {
                    let header_name = header.get_name_bytes(&buf);
                    let header_name = header_name.into_case_header_name();
                    let value_bytes = header.get_value_bytes(&buf);
                    let header_value = if cfg!(debug_assertions) {
                        // In debug builds `from_maybe_shared_unchecked()` still validates
                        // its input and panics on CR/LF. Obsolete line folding (enabled in
                        // the parser config) leaves those inside the value, possibly more
                        // than once, so blank out every CR and LF, not just the first CRLF.
                        if value_bytes.iter().any(|&b| b == b'\r' || b == b'\n') {
                            let mut new_header = Vec::from_iter(value_bytes);
                            for b in new_header.iter_mut() {
                                if *b == b'\r' || *b == b'\n' {
                                    *b = b' ';
                                }
                            }
                            // SAFETY: CR/LF were replaced above; the remaining bytes were
                            // accepted by httparse as a header value.
                            unsafe {
                                http::HeaderValue::from_maybe_shared_unchecked(new_header)
                            }
                        } else {
                            // SAFETY: bytes were accepted by httparse as a header value.
                            unsafe {
                                http::HeaderValue::from_maybe_shared_unchecked(value_bytes)
                            }
                        }
                    } else {
                        // SAFETY: bytes were accepted by httparse as a header value. NOTE:
                        // obs-folded values keep their CR/LF in release builds (pre-existing
                        // behavior).
                        unsafe { http::HeaderValue::from_maybe_shared_unchecked(value_bytes) }
                    };
                    response_header
                        .append_header(header_name, header_value)
                        .or_err(InvalidHTTPHeader, "while parsing response header")?;
                }
                let contains_transfer_encoding = response_header
                    .headers
                    .contains_key(header::TRANSFER_ENCODING);
                let contains_content_length =
                    response_header.headers.contains_key(header::CONTENT_LENGTH);
                // Transfer-Encoding overrides Content-Length (RFC 9112 §6.3), so drop the
                // latter when both are present.
                if contains_content_length && contains_transfer_encoding {
                    response_header.remove_header(&header::CONTENT_LENGTH);
                }
                self.buf = buf;
                self.response_header = Some(response_header);
                self.validate_response()?;
                self.upgraded = self
                    .is_upgrade(self.response_header.as_deref().expect("init above"))
                    .unwrap_or(false);
                self.init_body_reader();
                return Ok(s);
            }
            HeaderParseState::Partial => { /* need more bytes: keep reading */ }
            HeaderParseState::Invalid(e) => {
                return Error::e_because(
                    InvalidHTTPHeader,
                    format!("buf: {}", buf.escape_ascii()),
                    e,
                );
            }
        }
    }
}
/// Reads the next response header and returns an owned, boxed copy of it.
///
/// # Errors
/// Propagates any error from `read_response`.
pub async fn read_resp_header_parts(&mut self) -> Result<Box<ResponseHeader>> {
    self.read_response().await?;
    let header = self.resp_header().cloned().unwrap();
    Ok(Box::new(header))
}
/// The most recently parsed response header, if any.
pub fn resp_header(&self) -> Option<&ResponseHeader> {
    self.response_header.as_ref().map(|boxed| &**boxed)
}
/// Looks up a header value in the most recently parsed response header.
pub fn get_header(&self, name: impl AsHeaderName) -> Option<&HeaderValue> {
    self.resp_header()?.headers.get(name)
}
/// Like [`HttpSession::get_header`], but returns raw bytes (empty if absent).
pub fn get_header_bytes(&self, name: impl AsHeaderName) -> &[u8] {
    match self.get_header(name) {
        Some(value) => value.as_bytes(),
        None => &[],
    }
}
/// Status code of the most recently parsed response header, if any.
pub fn get_status(&self) -> Option<StatusCode> {
    let header = self.resp_header()?;
    Some(header.status)
}
/// Initializes the body reader if needed and reads the next body chunk, with no timeout.
async fn do_read_body(&mut self) -> Result<Option<BufRef>> {
    self.init_body_reader();
    let stream = &mut self.underlying_stream;
    self.body_reader.read_body(stream).await
}
pub async fn read_body_ref(&mut self) -> Result<Option<&[u8]>> {
let result = match self.read_timeout {
Some(t) => match timeout(t, self.do_read_body()).await {
Ok(res) => res,
Err(_) => Error::e_explain(ReadTimedout, format!("reading body, timeout: {t:?}")),
},
None => self.do_read_body().await,
};
result.map(|maybe_body| {
maybe_body.map(|body_ref| {
let slice = self.body_reader.get_body(&body_ref);
self.body_recv = self.body_recv.saturating_add(slice.len());
slice
})
})
}
/// Like [`HttpSession::read_body_ref`], but copies the chunk into an owned `Bytes`.
pub async fn read_body_bytes(&mut self) -> Result<Option<Bytes>> {
    let chunk = self.read_body_ref().await?;
    Ok(match chunk {
        Some(slice) => Some(Bytes::copy_from_slice(slice)),
        None => None,
    })
}
/// Total response body bytes returned to callers of `read_body_ref` so far.
pub fn body_bytes_received(&self) -> usize {
    let received = self.body_recv;
    received
}
/// Whether the response body has been fully read; initializes the body reader if needed.
pub fn is_body_done(&mut self) -> bool {
    self.init_body_reader();
    let reader = &self.body_reader;
    reader.body_done()
}
/// Controls whether responses with an unparseable `Content-Length` are tolerated.
pub fn set_allow_h1_response_invalid_content_length(&mut self, allow: bool) {
    self.allow_h1_response_invalid_content_length = allow;
}
/// Raw bytes of the last parsed response header.
///
/// # Panics
/// If no response header has been parsed yet.
pub(super) fn get_headers_raw(&self) -> &[u8] {
    let header_range = self.raw_header.as_ref().unwrap();
    header_range.get(&self.buf[..])
}
/// Owned `Bytes` view of the last parsed response header.
///
/// # Panics
/// If no response header has been parsed yet.
pub fn get_headers_raw_bytes(&self) -> Bytes {
    let header_range = self.raw_header.as_ref().unwrap();
    header_range.get_bytes(&self.buf)
}
/// Records the keepalive decision.
///
/// `None` disables reuse, `Some(0)` allows reuse with no timeout, and `Some(n)` allows reuse
/// with an `n`-second timeout.
fn set_keepalive(&mut self, seconds: Option<u64>) {
    self.keepalive_timeout = match seconds {
        None => KeepaliveStatus::Off,
        Some(0) => KeepaliveStatus::Infinite,
        Some(sec) => KeepaliveStatus::Timeout(Duration::from_secs(sec)),
    };
}
/// Decides whether the connection may be reused after this exchange.
///
/// Reuse is disabled in all of these cases:
/// - the session upgraded or got a 101;
/// - the body framing was never initialized;
/// - the body is close-delimited;
/// - bytes were read past the body;
/// - an HTTP/1.0 response carries `Transfer-Encoding`.
///
/// Otherwise the `Connection` header decides (a request `close` wins), using the
/// `Keep-Alive` timeout if present. Without a `Connection` verdict, HTTP/1.1 defaults to
/// reuse.
pub fn respect_keepalive(&mut self) {
    let version = self.resp_header().map(|h| h.version);
    let forced_close = self.upgraded
        || self.get_status() == Some(StatusCode::SWITCHING_PROTOCOLS)
        || self.body_reader.need_init()
        || self.close_delimited_resp
        || self.body_reader.has_bytes_overread()
        || (version == Some(Version::HTTP_10)
            && self.get_header(header::TRANSFER_ENCODING).is_some());
    let seconds = if forced_close {
        None
    } else {
        match self.is_connection_keepalive() {
            Some(true) => Some(self.get_keepalive_values().0.unwrap_or(0)),
            Some(false) => None,
            None if version == Some(Version::HTTP_11) => Some(0),
            None => None,
        }
    };
    self.set_keepalive(seconds);
}
/// Whether the last keepalive decision allows the connection to be reused.
pub fn will_keepalive(&self) -> bool {
    match self.keepalive_timeout {
        KeepaliveStatus::Off => false,
        _ => true,
    }
}
/// Combines the request's and response's `Connection` headers.
///
/// An explicit non-keepalive on the request wins. Otherwise the response's header decides,
/// and `None` means it expressed no preference.
fn is_connection_keepalive(&self) -> Option<bool> {
    let request_says = self
        .request_written
        .as_deref()
        .and_then(|req| is_buf_keepalive(req.headers.get(header::CONNECTION)));
    if request_says == Some(false) {
        return Some(false);
    }
    is_buf_keepalive(self.get_header(header::CONNECTION))
}
/// Parses the `timeout` and `max` parameters of the response's `Keep-Alive` header.
///
/// Parameter names are matched case-insensitively, following the general HTTP rule that
/// parameter names are case-insensitive. Values may be bare or wrapped in double quotes.
/// A missing header, non-UTF-8 header, or unparseable value yields `None` for that field.
/// If a parameter appears more than once, the last occurrence wins.
fn get_keepalive_values(&self) -> (Option<u64>, Option<usize>) {
    let Some(keep_alive_header) = self.get_header("Keep-Alive") else {
        return (None, None);
    };
    let Ok(header_value) = str::from_utf8(keep_alive_header.as_bytes()) else {
        return (None, None);
    };
    let mut timeout = None;
    let mut max = None;
    for param in header_value.split(',') {
        let Some((name, value)) = param.split_once('=') else {
            continue;
        };
        let name = name.trim();
        // Accept quoted-string values such as `timeout="5"`.
        let value = value.trim().trim_matches('"');
        if name.eq_ignore_ascii_case("timeout") {
            timeout = value.parse().ok();
        } else if name.eq_ignore_ascii_case("max") {
            max = value.parse().ok();
        }
    }
    (timeout, max)
}
/// Calls `shutdown()` on the underlying stream, ignoring any error.
pub async fn shutdown(&mut self) {
    self.underlying_stream.shutdown().await.ok();
}
/// Hands back the connection for reuse, or shuts it down if keepalive is off.
///
/// Returns `None` when the connection was shut down.
pub async fn reuse(mut self) -> Option<Stream> {
    if matches!(self.keepalive_timeout, KeepaliveStatus::Off) {
        debug!("HTTP shutdown connection");
        self.shutdown().await;
        return None;
    }
    Some(self.underlying_stream)
}
/// Chooses the response body framing the first time it is needed.
///
/// - HEAD requests, 204 and 304 get a zero-length body.
/// - Other 1xx responses (except 101) leave the reader uninitialized, since another header
///   follows.
/// - 101 to an upgrade request is close-delimited.
/// - Otherwise the body is chunked, sized by `Content-Length`, or close-delimited, in that
///   order of preference.
///
/// Bytes already read past the header are handed to the reader.
fn init_body_reader(&mut self) {
    if !self.body_reader.need_init() {
        return;
    }
    // `preread_body` is only set once a complete response header has been parsed (and is
    // taken while `read_response` is in progress). Without it there is nothing to frame yet,
    // so leave the reader uninitialized instead of panicking, e.g. when `is_body_done()` is
    // called before/after a failed header read.
    // NOTE(review): reading the body in that state still depends on BodyReader's handling
    // of an uninitialized reader — confirm.
    let Some(preread) = self.preread_body.as_ref() else {
        return;
    };
    let preread_body = preread.get(&self.buf[..]);
    if let Some(req) = self.request_written.as_ref() {
        if req.method == http::method::Method::HEAD {
            self.body_reader.init_content_length(0, preread_body);
            return;
        }
    }
    let upgraded = if let Some(code) = self.get_status() {
        match code.as_u16() {
            101 => self.is_upgrade_req(),
            100..=199 => {
                // Informational: the real response header is still to come.
                return;
            }
            204 | 304 => {
                self.body_reader.init_content_length(0, preread_body);
                return;
            }
            _ => false,
        }
    } else {
        false
    };
    if upgraded {
        self.body_reader.init_close_delimited(preread_body);
        self.close_delimited_resp = true;
    } else if self.is_chunked_encoding() {
        self.body_reader.init_chunked(preread_body);
    } else if let Some(cl) = self.get_content_length().unwrap_or(None) {
        self.body_reader.init_content_length(cl, preread_body);
    } else {
        // No usable framing headers: read until the peer closes.
        self.body_reader.init_close_delimited(preread_body);
        self.close_delimited_resp = true;
    }
}
/// Whether the request written on this session asked for a protocol upgrade.
///
/// Returns `false` if no request has been written.
pub fn is_upgrade_req(&self) -> bool {
    self.request_written
        .as_deref()
        .is_some_and(is_upgrade_req)
}
/// For an upgrade request, whether `header` accepts the upgrade. Returns `None` when the
/// request was not an upgrade request.
fn is_upgrade(&self, header: &ResponseHeader) -> Option<bool> {
    self.is_upgrade_req().then(|| is_upgrade_resp(header))
}
/// Whether the last response header completed a protocol upgrade.
pub fn was_upgraded(&self) -> bool {
    let upgraded = self.upgraded;
    upgraded
}
/// On an upgraded session, switches the request body writer to close-delimited mode and
/// records that the upgraded request body is being sent.
pub fn maybe_upgrade_body_writer(&mut self) {
    if !self.was_upgraded() {
        return;
    }
    self.received_upgrade_req_body = true;
    self.body_writer.convert_to_close_delimited();
}
fn get_content_length(&self) -> Result<Option<usize>> {
buf_to_content_length(
self.get_header(header::CONTENT_LENGTH)
.map(|v| v.as_bytes()),
)
}
fn is_chunked_encoding(&self) -> bool {
self.resp_header()
.map(|h| is_chunked_encoding_from_headers(&h.headers))
.unwrap_or(false)
}
/// Configures the request body writer from the outgoing request's headers.
fn init_req_body_writer(&mut self, header: &RequestHeader) {
    self.init_body_writer_comm(&header.headers)
}
/// Sets up body writer framing.
///
/// Chunked if declared; otherwise the `Content-Length` value, falling back to 0 when it is
/// absent or unusable.
fn init_body_writer_comm(&mut self, headers: &HMap) {
    if is_chunked_encoding_from_headers(headers) {
        self.body_writer.init_chunked();
        return;
    }
    let length =
        header_value_content_length(headers.get(http::header::CONTENT_LENGTH)).unwrap_or(0);
    self.body_writer.init_content_length(length);
}
/// Whether the next thing on the wire is a response header.
///
/// This holds before any header has been read, or after an informational (1xx) response
/// other than 101.
fn should_read_resp_header(&self) -> bool {
    match self.get_status() {
        None => true,
        Some(status) => matches!(status.as_u16(), 100 | 102..=199),
    }
}
/// Produces the next [`HttpTask`] from the response.
///
/// While a header is still expected (nothing read yet, or after a non-101 informational
/// response), it returns the next header together with whether the body is already
/// complete. After that it returns body chunks until the body is done, then `Done`. Once
/// the session is upgraded, chunks are returned as `UpgradedBody`.
///
/// # Errors
/// Propagates header and body read errors.
pub async fn read_response_task(&mut self) -> Result<HttpTask> {
    if self.should_read_resp_header() {
        let resp_header = self.read_resp_header_parts().await?;
        let end_of_body = self.is_body_done();
        debug!("Response header: {resp_header:?}");
        // Header bytes are not guaranteed to be UTF-8 (obs-text values, non-ASCII names), so
        // render them lossily; unwrapping here would panic whenever trace logging is on.
        trace!(
            "Raw Response header: {:?}",
            String::from_utf8_lossy(self.get_headers_raw())
        );
        Ok(HttpTask::Header(resp_header, end_of_body))
    } else if self.is_body_done() {
        debug!("Response is done");
        Ok(HttpTask::Done)
    } else {
        let body = self.read_body_bytes().await?;
        let end_of_body = self.is_body_done();
        debug!(
            "Response body: {} bytes, end: {end_of_body}",
            body.as_ref().map_or(0, |b| b.len())
        );
        trace!("Response body: {body:?}, upgraded: {}", self.upgraded);
        if self.upgraded {
            Ok(HttpTask::UpgradedBody(body, end_of_body))
        } else {
            Ok(HttpTask::Body(body, end_of_body))
        }
    }
}
pub fn digest(&self) -> &Digest {
&self.digest
}
pub fn digest_mut(&mut self) -> &mut Digest {
&mut self.digest
}
pub fn server_addr(&self) -> Option<&SocketAddr> {
self.digest()
.socket_digest
.as_ref()
.map(|d| d.peer_addr())?
}
pub fn client_addr(&self) -> Option<&SocketAddr> {
self.digest()
.socket_digest
.as_ref()
.map(|d| d.local_addr())?
}
pub fn stream(&self) -> &Stream {
&self.underlying_stream
}
pub fn into_inner(self) -> Stream {
self.underlying_stream
}
}
/// Runs httparse over `buf` and classifies the result for the header-reading loop.
///
/// The parser is lenient for responses: it accepts whitespace between a header name and
/// the colon, and obsolete multi-line header folding.
#[inline]
fn parse_resp_buffer<'buf>(
    resp: &mut httparse::Response<'_, 'buf>,
    buf: &'buf [u8],
) -> HeaderParseState {
    let mut config = httparse::ParserConfig::default();
    config
        .allow_spaces_after_header_name_in_responses(true)
        .allow_obsolete_multiline_headers_in_responses(true);
    match config.parse_response(resp, buf) {
        Ok(httparse::Status::Complete(header_len)) => HeaderParseState::Complete(header_len),
        Ok(httparse::Status::Partial) => HeaderParseState::Partial,
        Err(err) => HeaderParseState::Invalid(err),
    }
}
/// Serializes a request header into HTTP/1.x wire format.
///
/// The output is the request line, the headers, and the blank line that ends the header.
/// Returns `None` when the request's version has no textual form here (anything other
/// than HTTP/0.9, 1.0, 1.1 or 2).
#[inline]
pub fn http_req_header_to_wire(req: &RequestHeader) -> Option<BytesMut> {
    let version = match req.version {
        Version::HTTP_09 => "HTTP/0.9",
        Version::HTTP_10 => "HTTP/1.0",
        Version::HTTP_11 => "HTTP/1.1",
        Version::HTTP_2 => "HTTP/2",
        _ => return None,
    };
    let mut wire = BytesMut::with_capacity(512);
    // Request line: METHOD SP raw-path SP VERSION CRLF
    wire.put_slice(req.method.as_str().as_bytes());
    wire.put_u8(b' ');
    wire.put_slice(req.raw_path());
    wire.put_u8(b' ');
    wire.put_slice(version.as_bytes());
    wire.put_slice(CRLF);
    req.header_to_h1_wire(&mut wire);
    // Blank line ends the header section.
    wire.put_slice(CRLF);
    Some(wire)
}
/// A session is identified by the connection it runs on.
impl UniqueID for HttpSession {
    /// Returns the underlying stream's ID.
    fn id(&self) -> UniqueIDType {
        let stream = &self.underlying_stream;
        stream.id()
    }
}
#[cfg(test)]
mod tests_stream {
use super::*;
use crate::protocols::http::v1::body::{BodyMode, ParseState};
use crate::upstreams::peer::PeerOptions;
use crate::ErrorType;
use rstest::rstest;
use tokio_test::io::Builder;
fn init_log() {
let _ = env_logger::builder().is_test(true).try_init();
}
#[tokio::test]
async fn read_basic_response() {
init_log();
let input = b"HTTP/1.1 200 OK\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_response().await;
assert_eq!(input.len(), res.unwrap());
assert_eq!(0, http_stream.resp_header().unwrap().headers.len());
}
#[tokio::test]
async fn read_response_custom_reason() {
init_log();
let input = b"HTTP/1.1 200 Just Fine\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_response().await;
assert_eq!(input.len(), res.unwrap());
assert_eq!(
http_stream.resp_header().unwrap().get_reason_phrase(),
Some("Just Fine")
);
}
#[tokio::test]
async fn read_response_default() {
init_log();
let input_header = b"HTTP/1.1 200 OK\r\n\r\n";
let input_body = b"abc";
let input_close = b""; let mock_io = Builder::new()
.read(&input_header[..])
.read(&input_body[..])
.read(&input_close[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_response().await;
assert_eq!(input_header.len(), res.unwrap());
let res = http_stream.read_body_ref().await.unwrap();
assert_eq!(res.unwrap(), input_body);
assert_eq!(
http_stream.body_reader.body_state,
ParseState::UntilClose(3)
);
let res = http_stream.read_body_ref().await.unwrap();
assert_eq!(res, None);
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(3));
}
#[tokio::test]
async fn body_bytes_received_content_length() {
init_log();
let input_header = b"HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\n";
let input_body = b"abc";
let input_close = b""; let mock_io = Builder::new()
.read(&input_header[..])
.read(&input_body[..])
.read(&input_close[..])
.build();
let mut http = HttpSession::new(Box::new(mock_io));
http.read_response().await.unwrap();
let _ = http.read_body_ref().await.unwrap();
let _ = http.read_body_ref().await.unwrap();
assert_eq!(http.body_bytes_received(), 3);
}
#[tokio::test]
async fn body_bytes_received_chunked() {
init_log();
let input_header = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n";
let input_body = b"3\r\nabc\r\n0\r\n\r\n";
let mock_io = Builder::new()
.read(&input_header[..])
.read(&input_body[..])
.build();
let mut http = HttpSession::new(Box::new(mock_io));
http.read_response().await.unwrap();
let first = http.read_body_ref().await.unwrap();
assert_eq!(first.unwrap(), b"abc");
let _ = http.read_body_ref().await.unwrap();
assert_eq!(http.body_bytes_received(), 3);
}
#[tokio::test]
async fn h1_body_bytes_received_http10_until_close() {
init_log();
let header = b"HTTP/1.1 200 OK\r\n\r\n";
let body = b"abc";
let close = b"";
let mock = Builder::new()
.read(&header[..])
.read(&body[..])
.read(&close[..])
.build();
let mut http = HttpSession::new(Box::new(mock));
http.read_response().await.unwrap();
let _ = http.read_body_ref().await.unwrap();
let _ = http.read_body_ref().await.unwrap();
assert_eq!(http.body_bytes_received(), 3);
}
#[tokio::test]
async fn h1_body_bytes_received_chunked_multi() {
init_log();
let header = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n";
let body = b"1\r\na\r\n2\r\nbc\r\n0\r\n\r\n"; let mock = Builder::new().read(&header[..]).read(&body[..]).build();
let mut http = HttpSession::new(Box::new(mock));
http.read_response().await.unwrap();
let s1 = http.read_body_ref().await.unwrap().unwrap();
assert_eq!(s1, b"a");
let s2 = http.read_body_ref().await.unwrap().unwrap();
assert_eq!(s2, b"bc");
let _ = http.read_body_ref().await.unwrap();
assert_eq!(http.body_bytes_received(), 3);
}
#[tokio::test]
async fn h1_body_bytes_received_preread_in_header_buf() {
init_log();
let combined = b"HTTP/1.1 200 OK\r\n\r\nabc";
let close = b"";
let mock = Builder::new().read(&combined[..]).read(&close[..]).build();
let mut http = HttpSession::new(Box::new(mock));
http.read_response().await.unwrap();
let s = http.read_body_ref().await.unwrap().unwrap();
assert_eq!(s, b"abc");
let _ = http.read_body_ref().await.unwrap();
assert_eq!(http.body_bytes_received(), 3);
}
#[tokio::test]
async fn h1_body_bytes_received_overread_content_length() {
init_log();
let header1 = b"HTTP/1.1 200 OK\r\n";
let header2 = b"Content-Length: 2\r\n\r\n";
let body = b"abc"; let mock = Builder::new()
.read(&header1[..])
.read(&header2[..])
.read(&body[..])
.build();
let mut http = HttpSession::new(Box::new(mock));
http.read_response().await.unwrap();
let s = http.read_body_ref().await.unwrap().unwrap();
assert_eq!(s, b"ab");
let _ = http.read_body_ref().await.unwrap();
assert_eq!(http.body_bytes_received(), 2);
}
#[tokio::test]
async fn h1_body_bytes_received_after_100_continue() {
init_log();
let info = b"HTTP/1.1 100 Continue\r\n\r\n";
let header = b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\n";
let body = b"x";
let mock = Builder::new()
.read(&info[..])
.read(&header[..])
.read(&body[..])
.build();
let mut http = HttpSession::new(Box::new(mock));
match http.read_response_task().await.unwrap() {
HttpTask::Header(h, eob) => {
assert_eq!(h.status, 100);
assert!(!eob);
}
_ => panic!("expected informational header"),
}
match http.read_response_task().await.unwrap() {
HttpTask::Header(h, eob) => {
assert_eq!(h.status, 200);
assert!(!eob);
}
_ => panic!("expected final header"),
}
let s = http.read_body_ref().await.unwrap().unwrap();
assert_eq!(s, b"x");
let _ = http.read_body_ref().await.unwrap();
assert_eq!(http.body_bytes_received(), 1);
}
#[tokio::test]
async fn read_response_overread() {
init_log();
let input_header = b"HTTP/1.1 200 OK\r\n";
let input_header2 = b"Content-Length: 2\r\n\r\n";
let input_body = b"abc";
let mock_io = Builder::new()
.read(&input_header[..])
.read(&input_header2[..])
.read(&input_body[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_response().await;
assert_eq!(input_header.len() + input_header2.len(), res.unwrap());
let res = http_stream.read_body_ref().await.unwrap();
assert_eq!(res.unwrap(), &input_body[..2]);
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(2));
let res = http_stream.read_body_ref().await.unwrap();
assert_eq!(res, None);
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(2));
http_stream.respect_keepalive();
assert!(!http_stream.will_keepalive());
}
#[tokio::test]
async fn read_resp_header_with_space() {
init_log();
let input = b"HTTP/1.1 200 OK\r\nServer : pingora\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_response().await;
assert_eq!(input.len(), res.unwrap());
assert_eq!(1, http_stream.resp_header().unwrap().headers.len());
assert_eq!(http_stream.get_header("Server").unwrap(), "pingora");
}
#[cfg(feature = "patched_http1")]
#[tokio::test]
async fn read_resp_header_with_utf8() {
init_log();
let input = "HTTP/1.1 200 OK\r\nServer👍: pingora\r\n\r\n".as_bytes();
let mock_io = Builder::new().read(input).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let resp = http_stream.read_resp_header_parts().await.unwrap();
assert_eq!(1, http_stream.resp_header().unwrap().headers.len());
assert_eq!(http_stream.get_header("Server👍").unwrap(), "pingora");
assert_eq!(resp.headers.get("Server👍").unwrap(), "pingora");
}
#[tokio::test]
#[should_panic(expected = "There is still data left to read.")]
async fn read_timeout() {
init_log();
let input = b"HTTP/1.1 200 OK\r\n\r\n";
let mock_io = Builder::new()
.wait(Duration::from_secs(2))
.read(&input[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_timeout = Some(Duration::from_secs(1));
let res = http_stream.read_response().await;
assert_eq!(res.unwrap_err().etype(), &ErrorType::ReadTimedout);
}
#[tokio::test]
async fn read_2_buf() {
init_log();
let input1 = b"HTTP/1.1 200 OK\r\n";
let input2 = b"Server: pingora\r\n\r\n";
let mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_response().await;
assert_eq!(input1.len() + input2.len(), res.unwrap());
assert_eq!(
input1.len() + input2.len(),
http_stream.get_headers_raw().len()
);
assert_eq!(1, http_stream.resp_header().unwrap().headers.len());
assert_eq!(http_stream.get_header("Server").unwrap(), "pingora");
assert_eq!(Some(StatusCode::OK), http_stream.get_status());
assert_eq!(Version::HTTP_11, http_stream.resp_header().unwrap().version);
}
#[tokio::test]
#[should_panic(expected = "There is still data left to read.")]
async fn read_invalid() {
let input1 = b"HTP/1.1 200 OK\r\n";
let input2 = b"Server: pingora\r\n\r\n";
let mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_response().await;
assert_eq!(&ErrorType::InvalidHTTPHeader, res.unwrap_err().etype());
}
#[tokio::test]
async fn write() {
let wire = b"GET /test HTTP/1.1\r\nFoo: Bar\r\n\r\n";
let mock_io = Builder::new().write(wire).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let mut new_request = RequestHeader::build("GET", b"/test", None).unwrap();
new_request.insert_header("Foo", "Bar").unwrap();
let n = http_stream
.write_request_header(Box::new(new_request))
.await
.unwrap();
assert_eq!(wire.len(), n);
}
#[rstest]
#[case::negative("-1")]
#[case::not_a_number("abc")]
#[case::float("1.5")]
#[case::empty("")]
#[case::spaces(" ")]
#[case::mixed("123abc")]
#[tokio::test]
async fn validate_response_rejects_invalid_content_length(#[case] invalid_value: &str) {
init_log();
let input = format!(
"HTTP/1.1 200 OK\r\nServer: test\r\nContent-Length: {}\r\n\r\n",
invalid_value
);
let mock_io = Builder::new().read(input.as_bytes()).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_response().await;
assert!(res.is_err());
assert_eq!(res.unwrap_err().etype(), &ErrorType::InvalidHTTPHeader);
}
#[tokio::test]
async fn allow_invalid_content_length_close_delimited_when_configured() {
init_log();
let input_header = b"HTTP/1.1 200 OK\r\nServer: test\r\nContent-Length: abc\r\n\r\n";
let input_body = b"abc";
let input_close = b"";
let mock_io = Builder::new()
.read(&input_header[..])
.read(&input_body[..])
.read(&input_close[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let mut peer_options = PeerOptions::new();
peer_options.allow_h1_response_invalid_content_length = true;
http_stream.set_allow_h1_response_invalid_content_length(
peer_options.allow_h1_response_invalid_content_length,
);
let res = http_stream.read_response().await;
assert!(res.is_ok());
let body = http_stream.read_body_ref().await.unwrap().unwrap();
assert_eq!(body, input_body);
assert_eq!(
http_stream.body_reader.body_state,
ParseState::UntilClose(3)
);
let body = http_stream.read_body_ref().await.unwrap();
assert!(body.is_none());
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(3));
}
#[rstest]
#[case::valid_zero("0")]
#[case::valid_small("123")]
#[case::valid_large("999999")]
#[tokio::test]
async fn validate_response_accepts_valid_content_length(#[case] valid_value: &str) {
init_log();
let input = format!(
"HTTP/1.1 200 OK\r\nServer: test\r\nContent-Length: {}\r\n\r\n",
valid_value
);
let mock_io = Builder::new().read(input.as_bytes()).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_response().await;
assert!(res.is_ok());
}
#[tokio::test]
async fn validate_response_accepts_no_content_length() {
init_log();
let input = b"HTTP/1.1 200 OK\r\nServer: test\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_response().await;
assert!(res.is_ok());
}
#[rstest]
#[case(None, None, None)]
#[case(Some("transfer-encoding"), None, None)]
#[case(Some("transfer-encoding"), Some("CONTENT-LENGTH"), Some("4"))]
#[case(Some("TRANSFER-ENCODING"), Some("CONTENT-LENGTH"), Some("4"))]
#[case(Some("TRANSFER-ENCODING"), None, None)]
#[case(None, Some("CONTENT-LENGTH"), Some("4"))]
#[case(Some("TRANSFER-ENCODING"), Some("content-length"), Some("4"))]
#[case(None, Some("content-length"), Some("4"))]
#[case(Some("TRANSFER-ENCODING"), Some("CONTENT-LENGTH"), Some("abc"))]
#[tokio::test]
async fn response_transfer_encoding_and_content_length_handling(
#[case] transfer_encoding_header: Option<&str>,
#[case] content_length_header: Option<&str>,
#[case] content_length_value: Option<&str>,
) {
init_log();
let input1 = b"HTTP/1.1 200 OK\r\n";
let mut input2 = "Server: test\r\n".to_owned();
if let Some(transfer_encoding) = transfer_encoding_header {
input2 += &format!("{transfer_encoding}: chunked\r\n");
}
if let Some(content_length) = content_length_header {
let value = content_length_value.unwrap_or("4");
input2 += &format!("{content_length}: {value}\r\n")
}
input2 += "\r\n";
let mock_io = Builder::new()
.read(&input1[..])
.read(input2.as_bytes())
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let _ = http_stream.read_response().await.unwrap();
match (content_length_header, transfer_encoding_header) {
(Some(_) | None, Some(_)) => {
assert!(http_stream.get_header(header::TRANSFER_ENCODING).is_some());
assert!(http_stream.get_header(header::CONTENT_LENGTH).is_none());
}
(Some(_), None) => {
assert!(http_stream.get_header(header::TRANSFER_ENCODING).is_none());
assert!(http_stream.get_header(header::CONTENT_LENGTH).is_some());
}
_ => {
assert!(http_stream.get_header(header::CONTENT_LENGTH).is_none());
assert!(http_stream.get_header(header::TRANSFER_ENCODING).is_none());
}
}
}
#[tokio::test]
#[should_panic(expected = "There is still data left to write.")]
async fn write_timeout() {
let wire = b"GET /test HTTP/1.1\r\nFoo: Bar\r\n\r\n";
let mock_io = Builder::new()
.wait(Duration::from_secs(2))
.write(wire)
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.write_timeout = Some(Duration::from_secs(1));
let mut new_request = RequestHeader::build("GET", b"/test", None).unwrap();
new_request.insert_header("Foo", "Bar").unwrap();
let res = http_stream
.write_request_header(Box::new(new_request))
.await;
assert_eq!(res.unwrap_err().etype(), &ErrorType::WriteTimedout);
}
/// Writing a body chunk must fail with `WriteTimedout` when the peer accepts the
/// header but then stalls. The unconsumed scripted body write makes the mock
/// panic with the expected message on drop.
#[tokio::test]
#[should_panic(expected = "There is still data left to write.")]
async fn write_body_timeout() {
    let header = b"POST /test HTTP/1.1\r\nContent-Length: 3\r\n\r\n";
    let payload = b"abc";
    // the header goes through immediately, then a 2s stall before the body (timeout is 1s)
    let mock_io = Builder::new()
        .write(&header[..])
        .wait(Duration::from_secs(2))
        .write(&payload[..])
        .build();
    let mut session = HttpSession::new(Box::new(mock_io));
    session.write_timeout = Some(Duration::from_secs(1));

    let mut req = RequestHeader::build("POST", b"/test", None).unwrap();
    req.insert_header("Content-Length", "3").unwrap();
    session.write_request_header(Box::new(req)).await.unwrap();

    let outcome = session.write_body(payload).await;
    assert_eq!(outcome.unwrap_err().etype(), &WriteTimedout);
}
/// With the `patched_http1` feature, a request path that contains control bytes
/// and an incomplete UTF-8 sequence is written to the wire byte-for-byte.
#[cfg(feature = "patched_http1")]
#[tokio::test]
async fn write_invalid_path() {
    let expected = b"GET /\x01\xF0\x90\x80 HTTP/1.1\r\nFoo: Bar\r\n\r\n";
    let mut session = HttpSession::new(Box::new(Builder::new().write(expected).build()));

    let mut req = RequestHeader::build("GET", b"/\x01\xF0\x90\x80", None).unwrap();
    req.insert_header("Foo", "Bar").unwrap();

    let written = session.write_request_header(Box::new(req)).await.unwrap();
    assert_eq!(expected.len(), written);
}
/// A `100 Continue` delivered in its own read is surfaced as a header task that
/// does not end the response. It is followed by the final `204` header, which does.
#[tokio::test]
async fn read_informational() {
    init_log();
    let continue_resp = b"HTTP/1.1 100 Continue\r\n\r\n";
    let final_resp = b"HTTP/1.1 204 OK\r\nServer: pingora\r\n\r\n";
    let mock_io = Builder::new()
        .read(&continue_resp[..])
        .read(&final_resp[..])
        .build();
    let mut session = HttpSession::new(Box::new(mock_io));

    // (expected status, expected end-of-body flag) for each header task, in order
    for (status, end_of_body) in [(100u16, false), (204u16, true)] {
        if let HttpTask::Header(h, eob) = session.read_response_task().await.unwrap() {
            assert_eq!(h.status, status);
            assert_eq!(eob, end_of_body);
        } else {
            panic!("task should be header")
        }
    }
}
/// A `100 Continue` and the final `200` header arriving in the same read are
/// surfaced as two separate header tasks. The 3-byte body read afterwards ends
/// the response.
#[tokio::test]
async fn read_informational_combined_with_final() {
    init_log();
    let headers = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\nServer: pingora\r\nContent-Length: 3\r\n\r\n";
    let body = b"abc";
    let mock_io = Builder::new().read(&headers[..]).read(&body[..]).build();
    let mut session = HttpSession::new(Box::new(mock_io));

    // neither header ends the response: the 200 still has a body to deliver
    for expected_status in [100u16, 200u16] {
        if let HttpTask::Header(h, eob) = session.read_response_task().await.unwrap() {
            assert_eq!(h.status, expected_status);
            assert!(!eob);
        } else {
            panic!("task should be header")
        }
    }

    match session.read_response_task().await.unwrap() {
        HttpTask::Body(chunk, eob) => {
            assert_eq!(chunk.unwrap(), &body[..]);
            assert!(eob);
        }
        task => panic!("task {task:?} should be body"),
    }
}
/// Two informational responses (`100`, `103`) and a final `204` packed into a
/// single read are surfaced one header task at a time. Only the final header
/// ends the response.
#[tokio::test]
async fn read_informational_multiple_combined_with_final() {
    init_log();
    let packed = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 103 Early Hints\r\n\r\nHTTP/1.1 204 No Content\r\nServer: pingora\r\n\r\n";
    let mut session = HttpSession::new(Box::new(Builder::new().read(&packed[..]).build()));

    // each header in order: (status, whether it ends the response)
    let expected_headers = [(100u16, false), (103u16, false), (204u16, true)];
    for (status, ends_response) in expected_headers {
        if let HttpTask::Header(h, eob) = session.read_response_task().await.unwrap() {
            assert_eq!(h.status, status);
            assert_eq!(eob, ends_response);
        } else {
            panic!("task should be header")
        }
    }
}
/// After an informational `100`, the final response with a Content-Length body is
/// read to completion, and the connection stays eligible for keepalive.
#[tokio::test]
async fn read_informational_then_keepalive_response() {
    init_log();
    let request_wire = b"GET / HTTP/1.1\r\n\r\n";
    let continue_resp = b"HTTP/1.1 100 Continue\r\n\r\n";
    let final_resp = b"HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\n";
    let body = b"response body";
    let mock_io = Builder::new()
        .write(&request_wire[..])
        .read(&continue_resp[..])
        .read(&final_resp[..])
        .read(&body[..])
        .build();
    let mut session = HttpSession::new(Box::new(mock_io));

    let req = RequestHeader::build("GET", b"/", None).unwrap();
    session.write_request_header(Box::new(req)).await.unwrap();

    if let HttpTask::Header(h, eob) = session.read_response_task().await.unwrap() {
        assert_eq!(h.status, 100);
        assert!(!eob);
    } else {
        panic!("task should be informational header")
    }

    if let HttpTask::Header(h, eob) = session.read_response_task().await.unwrap() {
        assert_eq!(h.status, 200);
        // a 13-byte body is still expected, so the header does not end the response
        assert!(!eob);
    } else {
        panic!("task should be final header")
    }

    match session.read_response_task().await.unwrap() {
        HttpTask::Body(chunk, eob) => {
            assert_eq!(chunk.unwrap(), &body[..]);
            // all Content-Length bytes have arrived
            assert!(eob);
        }
        task => panic!("task {task:?} should be body"),
    }

    assert_eq!(session.body_reader.body_state, ParseState::Complete(13));
    session.respect_keepalive();
    assert!(session.will_keepalive());
}
/// Checks the request body writer and response reader across a
/// `101 Switching Protocols` upgrade. The request declared `Content-Length: 0`,
/// and the upgraded payload arrives in a read separate from the 101 header.
///
/// The mock I/O script is strictly ordered: request header write, 101 header
/// read, upgraded write of `ws_data`, then the upgraded read of `PAYLOAD`.
#[tokio::test]
async fn init_body_for_upgraded_req() {
let wire =
b"GET / HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: WS\r\nContent-Length: 0\r\n\r\n";
let input1 = b"HTTP/1.1 101 Switching Protocols\r\n\r\n";
let input2 = b"PAYLOAD";
let ws_data = b"data";
// Scripted I/O: each step must happen in exactly this order.
let mock_io = Builder::new()
.write(wire)
.read(&input1[..])
.write(&ws_data[..])
.read(&input2[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let mut new_request = RequestHeader::build("GET", b"/", None).unwrap();
new_request.insert_header("Connection", "Upgrade").unwrap();
new_request.insert_header("Upgrade", "WS").unwrap();
new_request.insert_header("Content-Length", "0").unwrap();
let _ = http_stream
.write_request_header(Box::new(new_request))
.await
.unwrap();
// Content-Length: 0 leaves the request body writer finished right after the header.
assert_eq!(
http_stream.body_writer.body_mode,
BodyMode::ContentLength(0, 0)
);
assert!(http_stream.body_writer.finished());
let task = http_stream.read_response_task().await.unwrap();
match task {
HttpTask::Header(h, eob) => {
assert_eq!(h.status, 101);
assert!(!eob);
}
_ => {
panic!("task should be header")
}
}
// After the 101, the response side is read until the connection closes.
assert_eq!(
http_stream.body_reader.body_state,
ParseState::UntilClose(0)
);
// The request writer stays finished until it is explicitly upgraded...
assert!(http_stream.body_writer.finished());
// ...after which it is reopened in UntilClose mode so upgraded bytes can be sent.
http_stream.maybe_upgrade_body_writer();
assert!(!http_stream.body_writer.finished());
assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(0));
http_stream.write_body(&ws_data[..]).await.unwrap();
// Bytes after the upgrade come back as UpgradedBody tasks, not regular Body tasks.
let task = http_stream.read_response_task().await.unwrap();
match task {
HttpTask::UpgradedBody(b, eob) => {
assert_eq!(b.unwrap(), &input2[..]);
// this PAYLOAD chunk is not flagged as end of body
assert!(!eob);
}
_ => {
panic!("task should be upgraded body")
}
}
}
/// Same upgrade flow as `init_body_for_upgraded_req`, except `PAYLOAD` arrives
/// in the same read as the 101 header. It must still be returned later as an
/// `UpgradedBody` task, even though the mock script has no further read.
#[tokio::test]
async fn init_preread_body_for_upgraded_req() {
let wire =
b"GET / HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: WS\r\nContent-Length: 0\r\n\r\n";
let input = b"HTTP/1.1 101 Switching Protocols\r\n\r\nPAYLOAD";
let ws_data = b"data";
// Scripted I/O: one read carries both the 101 header and the upgraded payload.
let mock_io = Builder::new()
.write(wire)
.read(&input[..])
.write(&ws_data[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let mut new_request = RequestHeader::build("GET", b"/", None).unwrap();
new_request.insert_header("Connection", "Upgrade").unwrap();
new_request.insert_header("Upgrade", "WS").unwrap();
new_request.insert_header("Content-Length", "0").unwrap();
let _ = http_stream
.write_request_header(Box::new(new_request))
.await
.unwrap();
// Content-Length: 0 leaves the request body writer finished right after the header.
assert_eq!(
http_stream.body_writer.body_mode,
BodyMode::ContentLength(0, 0)
);
assert!(http_stream.body_writer.finished());
let task = http_stream.read_response_task().await.unwrap();
match task {
HttpTask::Header(h, eob) => {
assert_eq!(h.status, 101);
assert!(!eob);
}
_ => {
panic!("task should be header")
}
}
// The response side is now read until close; no payload bytes have been consumed yet.
assert_eq!(
http_stream.body_reader.body_state,
ParseState::UntilClose(0)
);
assert!(http_stream.body_writer.finished());
// Reopen the request writer in UntilClose mode for the upgraded stream.
http_stream.maybe_upgrade_body_writer();
assert!(!http_stream.body_writer.finished());
assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(0));
http_stream.write_body(&ws_data[..]).await.unwrap();
// PAYLOAD was buffered alongside the 101 header and is handed out here without another read.
let task = http_stream.read_response_task().await.unwrap();
match task {
HttpTask::UpgradedBody(b, eob) => {
assert_eq!(b.unwrap(), &b"PAYLOAD"[..]);
assert!(!eob);
}
_ => {
panic!("task should be upgraded body")
}
}
}
/// The request declares a 10-byte body, and the server answers `101` before that
/// body is sent. The client must still write and finish its declared body, then
/// read the upgraded payload. Only then is the request writer switched to
/// UntilClose mode for upgraded traffic.
#[tokio::test]
async fn read_body_eos_after_upgrade() {
let wire =
b"GET / HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: WS\r\nContent-Length: 10\r\n\r\n";
let input1 = b"HTTP/1.1 101 Switching Protocols\r\n\r\n";
let input2 = b"PAYLOAD";
let body_data = b"0123456789";
let ws_data = b"data";
// Scripted I/O, strictly ordered: header write, 101 read, declared body write,
// upgraded read, upgraded write.
let mock_io = Builder::new()
.write(wire)
.read(&input1[..])
.write(&body_data[..])
.read(&input2[..])
.write(&ws_data[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let mut new_request = RequestHeader::build("GET", b"/", None).unwrap();
new_request.insert_header("Connection", "Upgrade").unwrap();
new_request.insert_header("Upgrade", "WS").unwrap();
new_request.insert_header("Content-Length", "10").unwrap();
let _ = http_stream
.write_request_header(Box::new(new_request))
.await
.unwrap();
// 10 bytes declared, none written yet, so the writer is not finished.
assert_eq!(
http_stream.body_writer.body_mode,
BodyMode::ContentLength(10, 0)
);
assert!(!http_stream.body_writer.finished());
let task = http_stream.read_response_task().await.unwrap();
match task {
HttpTask::Header(h, eob) => {
assert_eq!(h.status, 101);
assert!(!eob);
}
_ => {
panic!("task should be header")
}
}
assert_eq!(
http_stream.body_reader.body_state,
ParseState::UntilClose(0)
);
// Complete the declared request body even though the response already upgraded.
http_stream.write_body(&body_data[..]).await.unwrap();
http_stream.finish_body().await.unwrap();
let task = http_stream.read_response_task().await.unwrap();
match task {
HttpTask::UpgradedBody(b, eob) => {
assert_eq!(b.unwrap(), &input2[..]);
assert!(!eob);
}
t => {
panic!("task {t:?} should be upgraded body")
}
}
// The declared body is done; upgrading reopens the writer in UntilClose mode.
assert!(http_stream.body_writer.finished());
http_stream.maybe_upgrade_body_writer();
assert!(!http_stream.body_writer.finished());
assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(0));
http_stream.write_body(&ws_data[..]).await.unwrap();
// The UntilClose counter advanced by the 4 bytes of ws_data.
assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(4));
http_stream.finish_body().await.unwrap();
}
/// After a `101 Switching Protocols` response, the following bytes are surfaced
/// as `UpgradedBody` tasks. Once the mock has no more data to read, a final
/// `UpgradedBody(None, true)` marks the end of the upgraded stream.
#[tokio::test]
async fn read_switching_protocol() {
    init_log();
    let wire =
        b"GET / HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: WS\r\nContent-Length: 0\r\n\r\n";
    // Use the standard 101 reason phrase, matching the other upgrade fixtures. The
    // assertions below only check the status code, but "Continue" was misleading next
    // to the 100-Continue tests.
    let input1 = b"HTTP/1.1 101 Switching Protocols\r\n\r\n";
    let input2 = b"PAYLOAD";
    let mock_io = Builder::new()
        .write(&wire[..])
        .read(&input1[..])
        .read(&input2[..])
        .build();
    let mut http_stream = HttpSession::new(Box::new(mock_io));
    let mut new_request = RequestHeader::build("GET", b"/", None).unwrap();
    new_request.insert_header("Connection", "Upgrade").unwrap();
    new_request.insert_header("Upgrade", "WS").unwrap();
    new_request.insert_header("Content-Length", "0").unwrap();
    let _ = http_stream
        .write_request_header(Box::new(new_request))
        .await
        .unwrap();
    assert_eq!(
        http_stream.body_writer.body_mode,
        BodyMode::ContentLength(0, 0)
    );
    assert!(http_stream.body_writer.finished());

    let task = http_stream.read_response_task().await.unwrap();
    match task {
        HttpTask::Header(h, eob) => {
            assert_eq!(h.status, 101);
            assert!(!eob);
        }
        _ => {
            panic!("task should be header")
        }
    }

    let task = http_stream.read_response_task().await.unwrap();
    match task {
        HttpTask::UpgradedBody(b, eob) => {
            assert_eq!(b.unwrap(), &input2[..]);
            assert!(!eob);
        }
        _ => {
            panic!("task should be upgraded body")
        }
    }

    // the scripted reads are exhausted, so the upgraded stream ends here
    let task = http_stream.read_response_task().await.unwrap();
    match task {
        HttpTask::UpgradedBody(b, eob) => {
            assert!(b.is_none());
            assert!(eob);
        }
        _ => {
            panic!("task should be body with end of stream")
        }
    }
}
/// Obsolete line folding (a continuation line starting with a space or a tab) is
/// merged into the preceding `Server` header value instead of being parsed as a
/// separate header.
#[tokio::test]
async fn read_obsolete_multiline_headers() {
    init_log();
    // (raw response, expected folded `Server` value)
    let cases: [(&[u8], &str); 2] = [
        (
            b"HTTP/1.1 200 OK\r\nServer : pingora\r\n Foo: Bar\r\n\r\n",
            "pingora Foo: Bar",
        ),
        (
            b"HTTP/1.1 200 OK\r\nServer : pingora\r\n\t Fizz: Buzz\r\n\r\n",
            "pingora \t Fizz: Buzz",
        ),
    ];
    for (input, folded_server) in cases {
        let mut session = HttpSession::new(Box::new(Builder::new().read(input).build()));
        let consumed = session.read_response().await.unwrap();
        assert_eq!(input.len(), consumed);
        assert_eq!(1, session.resp_header().unwrap().headers.len());
        assert_eq!(session.get_header("Server").unwrap(), folded_server);
    }
}
/// With the `patched_http1` feature, a malformed header line (`;`) does not fail
/// the parse. Only the valid `Foo: Bar` header after it is kept.
#[cfg(feature = "patched_http1")]
#[tokio::test]
async fn read_headers_skip_invalid_line() {
    init_log();
    let raw = b"HTTP/1.1 200 OK\r\n;\r\nFoo: Bar\r\n\r\n";
    let mut session = HttpSession::new(Box::new(Builder::new().read(&raw[..]).build()));

    let consumed = session.read_response().await.unwrap();
    assert_eq!(raw.len(), consumed);

    let parsed = session.resp_header().unwrap();
    assert_eq!(1, parsed.headers.len());
    assert_eq!(session.get_header("Foo").unwrap(), "Bar");
}
/// `respect_keepalive` derives the keepalive status from the `Connection` header.
/// `get_keepalive_values` extracts `timeout`/`max` from the `Keep-Alive` header.
#[tokio::test]
async fn read_keepalive_headers() {
    init_log();

    // Reads a `Content-Length: 0` response carrying the given `Connection` value.
    async fn build_resp_with_keepalive(conn: &str) -> HttpSession {
        let raw =
            format!("HTTP/1.1 200 OK\r\nConnection: {conn}\r\nContent-Length: 0\r\n\r\n");
        let mut session = HttpSession::new(Box::new(Builder::new().read(raw.as_bytes()).build()));
        let consumed = session.read_response().await.unwrap();
        assert_eq!(raw.len(), consumed);
        session.respect_keepalive();
        session
    }

    // Reads a response carrying the given `Keep-Alive` value.
    async fn build_resp_with_keepalive_values(keep_alive: &str) -> HttpSession {
        let raw = format!("HTTP/1.1 200 OK\r\nKeep-Alive: {keep_alive}\r\n\r\n");
        let mut session = HttpSession::new(Box::new(Builder::new().read(raw.as_bytes()).build()));
        let consumed = session.read_response().await.unwrap();
        assert_eq!(raw.len(), consumed);
        session.respect_keepalive();
        session
    }

    // A `close` token turns keepalive off wherever it appears next to `upgrade`/`Upgrade`;
    // other values leave it infinite.
    let connection_cases = [
        ("close", KeepaliveStatus::Off),
        ("keep-alive", KeepaliveStatus::Infinite),
        ("foo", KeepaliveStatus::Infinite),
        ("upgrade,close", KeepaliveStatus::Off),
        ("upgrade, close", KeepaliveStatus::Off),
        ("Upgrade, close", KeepaliveStatus::Off),
        ("Upgrade,close", KeepaliveStatus::Off),
        ("close,upgrade", KeepaliveStatus::Off),
        ("close, upgrade", KeepaliveStatus::Off),
        ("close,Upgrade", KeepaliveStatus::Off),
        ("close, Upgrade", KeepaliveStatus::Off),
    ];
    for (conn, expected) in connection_cases {
        assert_eq!(
            build_resp_with_keepalive(conn).await.keepalive_timeout,
            expected
        );
    }

    // Parameter order and surrounding whitespace do not matter; unknown keys yield nothing.
    let value_cases = [
        ("timeout=5, max=1000", (Some(5), Some(1000))),
        ("max=1000, timeout=5", (Some(5), Some(1000))),
        (" timeout = 5, max = 1000 ", (Some(5), Some(1000))),
        ("timeout=5", (Some(5), None)),
        ("max=1000", (None, Some(1000))),
        ("a=b", (None, None)),
        ("", (None, None)),
    ];
    for (keep_alive, expected) in value_cases {
        assert_eq!(
            build_resp_with_keepalive_values(keep_alive)
                .await
                .get_keepalive_values(),
            expected
        );
    }
}
/// An HTTP/1.0 response framed with chunked Transfer-Encoding must not be kept
/// alive, even when it advertises `Connection: keep-alive`.
#[tokio::test]
async fn test_http10_response_with_transfer_encoding_disables_keepalive() {
    let response = b"HTTP/1.0 200 OK\r\n\
        Transfer-Encoding: chunked\r\n\
        Connection: keep-alive\r\n\
        \r\n\
        5\r\n\
        hello\r\n\
        0\r\n\
        \r\n";
    let mut session = HttpSession::new(Box::new(Builder::new().read(&response[..]).build()));
    session.read_response().await.unwrap();

    session.respect_keepalive();
    assert!(!session.will_keepalive());
    assert_eq!(session.keepalive_timeout, KeepaliveStatus::Off);
}
/// An HTTP/1.1 chunked response with no `Connection` header stays eligible for
/// keepalive.
#[tokio::test]
async fn test_http11_response_with_transfer_encoding_allows_keepalive() {
    let response = b"HTTP/1.1 200 OK\r\n\
        Transfer-Encoding: chunked\r\n\
        \r\n\
        5\r\n\
        hello\r\n\
        0\r\n\
        \r\n";
    let mut session = HttpSession::new(Box::new(Builder::new().read(&response[..]).build()));
    session.read_response().await.unwrap();

    session.respect_keepalive();
    assert!(session.will_keepalive());
}
/// When Transfer-Encoding is split across several headers and `chunked` comes
/// last, the response is treated as chunked and its body decodes to `hello`.
#[tokio::test]
async fn test_response_multiple_transfer_encoding_headers() {
    init_log();
    let response = b"HTTP/1.1 200 OK\r\n\
        Transfer-Encoding: gzip\r\n\
        Transfer-Encoding: chunked\r\n\
        \r\n\
        5\r\n\
        hello\r\n\
        0\r\n\
        \r\n";
    let mut session = HttpSession::new(Box::new(Builder::new().read(&response[..]).build()));
    session.read_response().await.unwrap();
    assert!(session.is_chunked_encoding());

    let chunk = session.read_body_bytes().await.unwrap().unwrap();
    assert_eq!(&chunk[..], b"hello");
    session.finish_body().await.unwrap();
}
/// When `chunked` is not the last coding across multiple Transfer-Encoding
/// headers, the response is not treated as chunked.
#[tokio::test]
async fn test_response_multiple_te_headers_chunked_not_last() {
    init_log();
    let response = b"HTTP/1.1 200 OK\r\n\
        Transfer-Encoding: chunked\r\n\
        Transfer-Encoding: identity\r\n\
        Content-Length: 5\r\n\
        \r\n\
        hello";
    let mut session = HttpSession::new(Box::new(Builder::new().read(&response[..]).build()));
    session.read_response().await.unwrap();
    assert!(!session.is_chunked_encoding());
}
/// Before any response has been read, the session does not report chunked encoding.
#[test]
fn test_is_chunked_encoding_before_response() {
    let session = HttpSession::new(Box::new(Builder::new().build()));
    assert!(!session.is_chunked_encoding());
}
/// A POST written without a Content-Length header gets a zero-length body writer.
#[tokio::test]
async fn write_request_body_implicit_zero_content_length() {
    init_log();
    let expected_header = b"POST /test HTTP/1.1\r\n\r\n";
    let mut session =
        HttpSession::new(Box::new(Builder::new().write(&expected_header[..]).build()));

    let req = RequestHeader::build("POST", b"/test", None).unwrap();
    session.write_request_header(Box::new(req)).await.unwrap();

    assert_eq!(session.body_writer.body_mode, BodyMode::ContentLength(0, 0));
}
/// With `Content-Length: 3`, the body writer reports `ContentLength(3, 0)` right
/// after the header and `ContentLength(3, 3)` once the 3-byte body is written.
#[tokio::test]
async fn write_request_body_with_content_length() {
    init_log();
    let expected_header = b"POST /test HTTP/1.1\r\nContent-Length: 3\r\n\r\n";
    let payload = b"abc";
    let mock_io = Builder::new()
        .write(&expected_header[..])
        .write(&payload[..])
        .build();
    let mut session = HttpSession::new(Box::new(mock_io));

    let mut req = RequestHeader::build("POST", b"/test", None).unwrap();
    req.insert_header("Content-Length", "3").unwrap();
    session.write_request_header(Box::new(req)).await.unwrap();
    assert_eq!(session.body_writer.body_mode, BodyMode::ContentLength(3, 0));

    session.write_body(payload).await.unwrap();
    assert_eq!(session.body_writer.body_mode, BodyMode::ContentLength(3, 3));
}
/// A response with neither Content-Length nor Transfer-Encoding is read until the
/// peer closes. Such a connection must not be kept alive afterwards.
#[tokio::test]
async fn close_delimited_response_explicitly_disables_keepalive() {
    init_log();
    let request_wire = b"GET / HTTP/1.1\r\n\r\n";
    let header = b"HTTP/1.1 200 OK\r\n\r\n";
    let body = b"abc";
    // an empty read, after which the scripted reads are exhausted
    let close = b"";
    let mock_io = Builder::new()
        .write(&request_wire[..])
        .read(&header[..])
        .read(&body[..])
        .read(&close[..])
        .build();
    let mut session = HttpSession::new(Box::new(mock_io));

    let req = RequestHeader::build("GET", b"/", None).unwrap();
    session.write_request_header(Box::new(req)).await.unwrap();
    session.read_response().await.unwrap();

    session.read_body_ref().await.unwrap();
    assert_eq!(session.body_reader.body_state, ParseState::UntilClose(3));

    // no more data: the body is complete at 3 bytes
    assert!(session.read_body_ref().await.unwrap().is_none());
    assert_eq!(session.body_reader.body_state, ParseState::Complete(3));

    session.respect_keepalive();
    assert!(!session.will_keepalive());
}
}
#[cfg(test)]
mod test_sync {
    use super::*;
    use log::error;

    /// A serialized request header must parse back with `httparse`, with the
    /// path and the custom `Foo` header intact.
    #[test]
    fn test_request_to_wire() {
        let mut req = RequestHeader::build("GET", b"/", None).unwrap();
        req.insert_header("Foo", "Bar").unwrap();
        let wire = http_req_header_to_wire(&req).unwrap();

        let mut parsed_headers = [httparse::EMPTY_HEADER; 128];
        let mut parsed = httparse::Request::new(&mut parsed_headers);
        let status = parsed.parse(wire.as_ref());
        // log the parse error before the unwrap below turns it into a panic
        if let Err(e) = &status {
            error!("{:?}", e);
        }
        assert!(status.unwrap().is_complete());
        assert_eq!("/", parsed.path.unwrap());
        assert_eq!(b"Foo", parsed_headers[0].name.as_bytes());
        assert_eq!(b"Bar", parsed_headers[0].value);
    }
}