use bstr::ByteSlice;
use bytes::Bytes;
use bytes::{BufMut, BytesMut};
use http::header::{CONTENT_LENGTH, TRANSFER_ENCODING};
use http::HeaderValue;
use http::{header, header::AsHeaderName, Method, Version};
use log::{debug, trace, warn};
use once_cell::sync::Lazy;
use percent_encoding::{percent_encode, AsciiSet, CONTROLS};
use pingora_error::{Error, ErrorType::*, OrErr, Result};
use pingora_http::{IntoCaseHeaderName, RequestHeader, ResponseHeader};
use pingora_timeout::timeout;
use regex::bytes::Regex;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::body::{BodyReader, BodyWriter};
use super::common::*;
use crate::protocols::http::{body_buffer::FixedBuffer, date, HttpTask};
use crate::protocols::{Digest, SocketAddr, Stream};
use crate::utils::{BufRef, KVRef};
/// An HTTP/1.x server-side session: reads one request from the downstream
/// connection and writes the response back on the same connection.
pub struct HttpSession {
/// The downstream connection being served.
underlying_stream: Stream,
/// Frozen buffer holding the raw request header bytes plus any body bytes that
/// were read along with them; `raw_header` and `preread_body` index into it.
buf: Bytes,
/// Span of the request header within `buf`.
raw_header: Option<BufRef>,
/// Span of body bytes that arrived in the same reads as the request header.
preread_body: Option<BufRef>,
/// Reader for the request body.
body_reader: BodyReader,
/// Writer for the response body.
body_writer: BodyWriter,
/// Staging buffer for body chunks batched by `response_duplex_vec`.
body_write_buf: BytesMut,
/// Bytes written downstream: response header bytes plus response body bytes.
body_bytes_sent: usize,
/// Request body bytes returned by `read_body_bytes`.
body_bytes_read: usize,
/// Whether `Date` and `Connection` are inserted into final (non-1xx) response headers.
update_resp_headers: bool,
/// Current keepalive state of the connection.
keepalive_timeout: KeepaliveStatus,
/// Timeout applied to each read of the request header (when keepalive is off)
/// and to each request body read; 60 seconds by default.
read_timeout: Option<Duration>,
/// Timeout for body writes, used when no positive `min_send_rate` is set.
write_timeout: Option<Duration>,
/// Overall timeout for draining the remaining request body.
total_drain_timeout: Option<Duration>,
/// The most recently written response header (may be a 1xx).
response_written: Option<Box<ResponseHeader>>,
/// The parsed request header.
request_header: Option<Box<RequestHeader>>,
/// Optional bounded (`BODY_BUF_LIMIT`) copy of the request body read so far.
retry_buffer: Option<FixedBuffer>,
/// Set once a successful upgrade handshake response has been written.
upgraded: bool,
/// Connection metadata (TLS, timing, proxy, socket) captured when the session is created.
digest: Box<Digest>,
/// Minimum send rate in bytes per second, used to derive per-write timeouts.
min_send_rate: Option<usize>,
/// When set, informational (1xx) responses are dropped, except 101 and a 100
/// answering an `Expect: 100-continue` request.
ignore_info_resp: bool,
/// When set, keepalive is disabled if the response header is written before
/// the request body has been fully read.
close_on_response_before_downstream_finish: bool,
/// How many more times the connection may be reused; `Some(0)` disables keepalive.
keepalive_reuses_remaining: Option<u32>,
}
impl HttpSession {
pub fn new(underlying_stream: Stream) -> Self {
let digest = Box::new(Digest {
ssl_digest: underlying_stream.get_ssl_digest(),
timing_digest: underlying_stream.get_timing_digest(),
proxy_digest: underlying_stream.get_proxy_digest(),
socket_digest: underlying_stream.get_socket_digest(),
});
HttpSession {
underlying_stream,
buf: Bytes::new(), raw_header: None,
preread_body: None,
body_reader: BodyReader::new(false),
body_writer: BodyWriter::new(),
body_write_buf: BytesMut::new(),
keepalive_timeout: KeepaliveStatus::Off,
update_resp_headers: true,
response_written: None,
request_header: None,
read_timeout: Some(Duration::from_secs(60)),
write_timeout: None,
total_drain_timeout: None,
body_bytes_sent: 0,
body_bytes_read: 0,
retry_buffer: None,
upgraded: false,
digest,
min_send_rate: None,
ignore_info_resp: false,
close_on_response_before_downstream_finish: true,
keepalive_reuses_remaining: None,
}
}
/// Read and parse the next request header from the downstream connection.
///
/// Bytes are read into a fresh buffer until `httparse` reports a complete
/// header. On success the session stores:
/// - the parsed header,
/// - the raw header span,
/// - any body bytes read along with the header.
///
/// It then sets keepalive from the request and runs `validate_request`.
///
/// Returns `Ok(Some(n))` with the header length `n` in bytes. Returns
/// `Ok(None)` when the keepalive timeout expires, when the peer closes before
/// sending any byte, or when a read fails before any byte was read.
///
/// # Errors
/// - `InvalidHTTPHeader`: the header exceeds `MAX_HEADER_SIZE`, is
///   unparsable, or a header cannot be appended.
/// - `ReadTimedout`: `read_timeout` expires while keepalive is off.
/// - `ConnectionClosed` / `ReadError`: the peer closes or the read fails
///   after some header bytes were read.
/// - Errors from `RequestHeader::build` and `validate_request` are
///   propagated.
pub async fn read_request(&mut self) -> Result<Option<usize>> {
const MAX_ERR_BUF_LEN: usize = 2048;
self.buf.clear();
let mut buf = BytesMut::with_capacity(INIT_HEADER_BUF_SIZE);
let mut already_read: usize = 0;
loop {
if already_read > MAX_HEADER_SIZE {
return Error::e_explain(
InvalidHTTPHeader,
format!("Request header larger than {MAX_HEADER_SIZE}"),
);
}
// Which timeout applies depends on the keepalive state: an idle keepalive
// connection timing out is a clean end (Ok(None)), not an error.
let read_result = {
let read_event = self.underlying_stream.read_buf(&mut buf);
match self.keepalive_timeout {
KeepaliveStatus::Timeout(d) => match timeout(d, read_event).await {
Ok(res) => res,
Err(e) => {
debug!("keepalive timeout {d:?} reached, {e}");
return Ok(None);
}
},
KeepaliveStatus::Infinite => {
read_event.await
}
KeepaliveStatus::Off => match self.read_timeout {
Some(t) => match timeout(t, read_event).await {
Ok(res) => res,
Err(e) => {
debug!("read timeout {t:?} reached, {e}");
return Error::e_explain(ReadTimedout, format!("timeout: {t:?}"));
}
},
None => read_event.await,
},
}
};
let n = match read_result {
Ok(n_read) => {
if n_read == 0 {
if already_read > 0 {
return Error::e_explain(
ConnectionClosed,
format!(
"while reading request headers, bytes already read: {}",
already_read
),
);
} else {
debug!("Client prematurely closed connection with 0 byte sent");
return Ok(None);
}
}
n_read
}
Err(e) => {
if already_read > 0 {
return Error::e_because(ReadError, "while reading request headers", e);
}
return Ok(None);
}
};
already_read += n;
// Inner loop: reparse the same buffer after an illegal request line has
// been rewritten; `break` returns to the outer loop to read more bytes.
loop {
let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut req = httparse::Request::new(&mut headers);
let parsed = parse_req_buffer(&mut req, &buf);
match parsed {
HeaderParseState::Complete(s) => {
// Everything after the header end that was already read belongs to the body.
self.raw_header = Some(BufRef(0, s));
self.preread_body = Some(BufRef(s, already_read));
let base = buf.as_ptr() as usize;
let mut header_refs = Vec::<KVRef>::with_capacity(req.headers.len());
let _num_headers = populate_headers(base, &mut header_refs, req.headers);
let mut request_header = Box::new(RequestHeader::build(
req.method.unwrap_or(""),
req.path.unwrap_or("").as_bytes(),
Some(req.headers.len()),
)?);
request_header.set_version(match req.version {
Some(1) => Version::HTTP_11,
Some(0) => Version::HTTP_10,
_ => Version::HTTP_09,
});
// Freeze so header names/values can be sliced out of the buffer without copying.
let buf = buf.freeze();
for header in header_refs {
let header_name = header.get_name_bytes(&buf);
let header_name = header_name.into_case_header_name();
let value_bytes = header.get_value_bytes(&buf);
// SAFETY: `from_maybe_shared_unchecked` skips `HeaderValue` byte
// validation; these bytes were accepted as a header value by the
// httparse parse above. NOTE(review): with the `patched_http1`
// feature `parse_unchecked` is used instead — confirm it applies
// the same value-byte checks.
let header_value = unsafe {
http::HeaderValue::from_maybe_shared_unchecked(value_bytes)
};
request_header
.append_header(header_name, header_value)
.or_err(InvalidHTTPHeader, "while parsing request header")?;
}
let contains_transfer_encoding =
request_header.headers.contains_key(TRANSFER_ENCODING);
let contains_content_length =
request_header.headers.contains_key(CONTENT_LENGTH);
let has_both_te_and_cl =
contains_content_length && contains_transfer_encoding;
// Transfer-Encoding overrides Content-Length (RFC 9112 §6.3), so
// drop Content-Length; keepalive is also disabled below so a
// connection with ambiguous framing is not reused.
if has_both_te_and_cl {
request_header.remove_header(&CONTENT_LENGTH);
}
self.buf = buf;
self.request_header = Some(request_header);
self.body_reader.reinit();
self.response_written = None;
self.respect_keepalive();
if has_both_te_and_cl {
self.set_keepalive(None);
}
self.validate_request()?;
return Ok(Some(s));
}
HeaderParseState::Partial => {
break;
}
HeaderParseState::Invalid(e) => match e {
// Possibly an unescaped URI in the request line: percent-encode it
// and reparse. `escape_illegal_request_line` returns None when
// nothing would change, so this cannot loop forever.
httparse::Error::Token | httparse::Error::Version => {
if let Some(new_buf) = escape_illegal_request_line(&buf) {
buf = new_buf;
already_read = buf.len();
} else {
debug!("Invalid request header from {:?}", self.underlying_stream);
buf.truncate(MAX_ERR_BUF_LEN);
return Error::e_because(
InvalidHTTPHeader,
format!("buf: {}", buf.escape_ascii()),
e,
);
}
}
_ => {
debug!("Invalid request header from {:?}", self.underlying_stream);
buf.truncate(MAX_ERR_BUF_LEN);
return Error::e_because(
InvalidHTTPHeader,
format!("buf: {:?}", buf.as_bstr()),
e,
);
}
},
}
}
}
}
/// Reject request headers whose body framing is ambiguous or unsupported.
///
/// # Errors
/// - Whatever `check_dup_content_length` rejects.
/// - `InvalidHTTPHeader` when `Transfer-Encoding` appears on an HTTP/1.0
///   request.
/// - `InvalidHTTPHeader` when the final transfer coding is not chunked.
/// - Any error from parsing the `Content-Length` value.
///
/// # Panics
/// If no request header has been read yet.
pub fn validate_request(&self) -> Result<()> {
    let req = self.req_header();
    super::common::check_dup_content_length(&req.headers)?;

    // RFC 9112 §6.1: a request whose final transfer coding is not chunked
    // must be rejected with 400 and the connection closed.
    let has_transfer_encoding = req.headers.contains_key(TRANSFER_ENCODING);
    if has_transfer_encoding && req.version == http::Version::HTTP_10 {
        return Error::e_explain(
            InvalidHTTPHeader,
            "HTTP/1.0 requests cannot include Transfer-Encoding header",
        );
    }
    if has_transfer_encoding && !self.is_chunked_encoding() {
        return Error::e_explain(InvalidHTTPHeader, "non-chunked final Transfer-Encoding");
    }

    // Parse Content-Length now so a malformed value fails the request up front.
    self.get_content_length().map(|_| ())
}
/// Borrow the parsed request header.
///
/// # Panics
/// If no request header has been read yet.
pub fn req_header(&self) -> &RequestHeader {
    self.request_header
        .as_deref()
        .expect("Request header is not read yet")
}

/// Mutably borrow the parsed request header.
///
/// # Panics
/// If no request header has been read yet.
pub fn req_header_mut(&mut self) -> &mut RequestHeader {
    self.request_header
        .as_deref_mut()
        .expect("Request header is not read yet")
}
/// Look up a request header by name.
///
/// Returns `None` if the header is absent or no request has been read.
pub fn get_header(&self, name: impl AsHeaderName) -> Option<&HeaderValue> {
    let req = self.request_header.as_ref()?;
    req.headers.get(name)
}

/// The request method, if a request has been read.
pub(crate) fn get_method(&self) -> Option<&http::Method> {
    self.request_header.as_deref().map(|req| &req.method)
}

/// The raw request path, or an empty slice before a request is read.
pub(crate) fn get_path(&self) -> &[u8] {
    match self.request_header.as_deref() {
        Some(req) => req.raw_path(),
        None => b"",
    }
}

/// The `Host` request header value, or an empty slice if it is unavailable.
pub(crate) fn get_host(&self) -> &[u8] {
    self.get_header(header::HOST)
        .map_or(b"", |host| host.as_bytes())
}

/// A short one-line description of the request: `"<method> <path>, Host: <host>"`.
///
/// A missing method is shown as `-`; path and host are decoded lossily.
pub fn request_summary(&self) -> String {
    let method = self.get_method().map_or("-", |m| m.as_str());
    let path = String::from_utf8_lossy(self.get_path());
    let host = String::from_utf8_lossy(self.get_host());
    format!("{} {}, Host: {}", method, path, host)
}

/// Whether the request asks for a protocol upgrade (false before a request is read).
pub fn is_upgrade_req(&self) -> bool {
    self.request_header
        .as_deref()
        .map_or(false, |req| is_upgrade_req(req))
}

/// The raw bytes of a request header, or an empty slice if unavailable.
pub fn get_header_bytes(&self, name: impl AsHeaderName) -> &[u8] {
    match self.get_header(name) {
        Some(value) => value.as_bytes(),
        None => b"",
    }
}
pub async fn read_body_bytes(&mut self) -> Result<Option<Bytes>> {
let read = self.read_body().await?;
Ok(read.map(|b| {
let bytes = Bytes::copy_from_slice(self.get_body(&b));
self.body_bytes_read += bytes.len();
if let Some(buffer) = self.retry_buffer.as_mut() {
buffer.write_to_buffer(&bytes);
}
bytes
}))
}
async fn do_read_body(&mut self) -> Result<Option<BufRef>> {
self.init_body_reader();
self.body_reader
.read_body(&mut self.underlying_stream)
.await
}
async fn read_body(&mut self) -> Result<Option<BufRef>> {
match self.read_timeout {
Some(t) => match timeout(t, self.do_read_body()).await {
Ok(res) => res,
Err(_) => Error::e_explain(ReadTimedout, format!("reading body, timeout: {t:?}")),
},
None => self.do_read_body().await,
}
}
async fn do_drain_request_body(&mut self) -> Result<()> {
loop {
match self.read_body_bytes().await {
Ok(Some(_)) => { }
Ok(None) => return Ok(()), Err(e) => return Err(e),
}
}
}
pub async fn drain_request_body(&mut self) -> Result<()> {
if self.is_body_done() {
return Ok(());
}
match self.total_drain_timeout {
Some(t) => match timeout(t, self.do_drain_request_body()).await {
Ok(res) => res,
Err(_) => Error::e_explain(ReadTimedout, format!("draining body, timeout: {t:?}")),
},
None => self.do_drain_request_body().await,
}
}
pub fn is_body_done(&mut self) -> bool {
self.init_body_reader();
self.body_reader.body_done()
}
pub fn is_body_empty(&mut self) -> bool {
self.init_body_reader();
self.body_reader.body_empty()
}
/// Write a response header downstream.
///
/// Informational (1xx) headers may be skipped (see `set_ignore_info_resp`).
/// Once a final header, or a successful upgrade, has been written, further
/// headers are refused with a warning.
///
/// For final headers with `update_resp_headers` set, `Date` and `Connection`
/// are inserted. `Connection` is decided only after the body framing is known,
/// so a response whose body is close-delimited (which disables keepalive)
/// is announced with `Connection: close` instead of `keep-alive`.
///
/// # Errors
/// `WriteError` if writing or flushing the header fails. Header insertion
/// errors are propagated.
pub async fn write_response_header(&mut self, mut header: Box<ResponseHeader>) -> Result<()> {
    if header.status.is_informational() && self.ignore_info_resp(header.status.into()) {
        debug!("ignoring informational headers");
        return Ok(());
    }

    if let Some(resp) = self.response_written.as_ref() {
        if !resp.status.is_informational() || self.upgraded {
            warn!("Respond header is already sent, cannot send again");
            return Ok(());
        }
    }

    // `is_body_done()` initializes the body reader, which requires a parsed
    // request, so a missing request header counts as "not finished".
    if self.close_on_response_before_downstream_finish
        && (self.request_header.is_none() || !self.is_body_done())
    {
        debug!("set connection close before downstream finish");
        self.set_keepalive(None);
    }

    // Only final responses get the automatically managed headers.
    let update_final_headers = !header.status.is_informational() && self.update_resp_headers;
    if update_final_headers {
        header.insert_header(header::DATE, date::get_cached_date())?;
    }

    if header.status == 101 {
        // the connection must be closed at the end once it has been upgraded
        self.set_keepalive(None);
    }

    // Informational headers other than 101 pass through without touching body state.
    if header.status == 101 || !header.status.is_informational() {
        if let Some(upgrade_ok) = self.is_upgrade(&header) {
            if upgrade_ok {
                debug!("ok upgrade handshake");
                self.upgraded = true;
                // From now on the request body is read as an upgraded pass-through stream.
                if self.body_reader.need_init() {
                    self.init_body_reader();
                } else {
                    self.body_reader.convert_to_close_delimited();
                }
            } else {
                debug!("bad upgrade handshake!");
            }
        }
        self.init_body_writer(&header);
    }

    // A close-delimited body can only be terminated by closing the connection.
    if self.body_writer.is_close_delimited() {
        self.set_keepalive(None);
    }

    // Decided only now: the body framing above may just have disabled keepalive,
    // and the header must not promise keep-alive on a connection that will close.
    if update_final_headers {
        let connection_value = if self.will_keepalive() {
            "keep-alive"
        } else {
            "close"
        };
        header.insert_header(header::CONNECTION, connection_value)?;
    }

    // Flush right away for 1xx (the client must see it before the rest) and for
    // responses without Content-Length (likely generated in real time).
    let flush = header.status.is_informational()
        || header.headers.get(header::CONTENT_LENGTH).is_none();

    let mut write_buf = BytesMut::with_capacity(INIT_HEADER_BUF_SIZE);
    http_resp_header_to_buf(&header, &mut write_buf).unwrap();
    match self.underlying_stream.write_all(&write_buf).await {
        Ok(()) => {
            if flush || self.body_writer.finished() {
                self.underlying_stream
                    .flush()
                    .await
                    .or_err(WriteError, "flushing response header")?;
            }
            self.response_written = Some(header);
            self.body_bytes_sent += write_buf.len();
            Ok(())
        }
        Err(e) => Error::e_because(WriteError, "writing response header", e),
    }
}
/// The most recently written response header, if any (may be a 1xx).
pub fn response_written(&self) -> Option<&ResponseHeader> {
    self.response_written.as_ref().map(|resp| &**resp)
}

/// For an upgrade request, whether `header` completes the upgrade handshake.
///
/// Returns `None` when the request is not an upgrade request.
pub fn is_upgrade(&self, header: &ResponseHeader) -> Option<bool> {
    self.is_upgrade_req().then(|| is_upgrade_resp(header))
}

/// Whether a successful upgrade response has been written on this session.
pub fn was_upgraded(&self) -> bool {
    let upgraded = self.upgraded;
    upgraded
}
/// Set the keepalive state: `None` turns keepalive off, `Some(0)` keeps the
/// connection alive indefinitely, `Some(n)` uses an `n`-second idle timeout.
fn set_keepalive(&mut self, seconds: Option<u64>) {
    self.keepalive_timeout = match seconds {
        None => KeepaliveStatus::Off,
        Some(0) => KeepaliveStatus::Infinite,
        Some(sec) => KeepaliveStatus::Timeout(Duration::from_secs(sec)),
    };
}

/// The keepalive state in the same encoding `set_keepalive` accepts.
pub fn get_keepalive_timeout(&self) -> Option<u64> {
    match self.keepalive_timeout {
        KeepaliveStatus::Off => None,
        KeepaliveStatus::Infinite => Some(0),
        KeepaliveStatus::Timeout(d) => Some(d.as_secs()),
    }
}

/// Limit how many more times the connection may be reused (`Some(0)` forbids reuse).
pub fn set_keepalive_reuses_remaining(&mut self, remaining: Option<u32>) {
    self.keepalive_reuses_remaining = remaining;
}

/// The remaining reuse budget, if one is set.
pub fn get_keepalive_reuses_remaining(&self) -> Option<u32> {
    self.keepalive_reuses_remaining
}

/// Whether the connection will be kept alive after this response:
/// keepalive must be on and the reuse budget must not be exhausted.
pub fn will_keepalive(&self) -> bool {
    let keepalive_on = !matches!(self.keepalive_timeout, KeepaliveStatus::Off);
    let reuses_left = self.keepalive_reuses_remaining != Some(0);
    keepalive_on && reuses_left
}
/// Client-requested keepalive parameters as `(timeout_secs, max_uses)`.
///
/// Not implemented: always returns `(None, None)`.
/// NOTE(review): presumably meant to read the `Keep-Alive` request header.
fn get_keepalive_values(&self) -> (Option<u64>, Option<usize>) {
    let (timeout_secs, max_uses) = (None, None);
    (timeout_secs, max_uses)
}

/// Whether an informational response with `status` should be dropped.
///
/// 101 is never dropped, nor is a 100 answering an `Expect: 100-continue` request.
fn ignore_info_resp(&self, status: u16) -> bool {
    if !self.ignore_info_resp || status == 101 {
        return false;
    }
    !(status == 100 && self.is_expect_continue_req())
}

/// Whether the request carries `Expect: 100-continue` (false before a request is read).
fn is_expect_continue_req(&self) -> bool {
    self.request_header
        .as_deref()
        .map_or(false, |req| is_expect_continue_req(req))
}

/// The request's `Connection` header as interpreted by `is_buf_keepalive`.
fn is_connection_keepalive(&self) -> Option<bool> {
    let connection = self.get_header(header::CONNECTION);
    is_buf_keepalive(connection)
}

/// Timeout for writing `buf_len` bytes.
///
/// With a positive `min_send_rate` (bytes/sec), the timeout is
/// `max(buf_len, rate) / rate` seconds, so never below one second. Otherwise
/// the configured `write_timeout` is used.
fn write_timeout(&self, buf_len: usize) -> Option<Duration> {
    match self.min_send_rate {
        Some(rate) if rate > 0 => {
            let ms = (buf_len.max(rate) as f64 / rate as f64) * 1000.0;
            // `as u64` saturates absurdly large values
            Some(Duration::from_millis(ms as u64))
        }
        _ => self.write_timeout,
    }
}
/// Set keepalive according to the request.
///
/// An explicit `Connection` header wins. Without one, keepalive is on
/// (indefinitely) for HTTP/1.1 and off for older versions.
///
/// # Panics
/// If no `Connection` header decides it and no request header has been read.
pub fn respect_keepalive(&mut self) {
    let seconds = match self.is_connection_keepalive() {
        Some(true) => {
            // TODO: respect max_use
            let (idle_secs, _max_use) = self.get_keepalive_values();
            Some(idle_secs.unwrap_or(0))
        }
        Some(false) => None,
        None if self.req_header().version == Version::HTTP_11 => Some(0),
        None => None,
    };
    self.set_keepalive(seconds);
}
/// Configure the response body writer from the response header.
///
/// - 204, 304 and responses to HEAD get a zero-length body.
/// - 1xx other than 101 leave the writer untouched.
/// - Successful upgrades are close-delimited.
/// - Everything else follows the response headers.
fn init_body_writer(&mut self, header: &ResponseHeader) {
    use http::StatusCode;
    let status = header.status;
    let bodyless = status == StatusCode::NO_CONTENT
        || status == StatusCode::NOT_MODIFIED
        || self.get_method() == Some(&Method::HEAD);

    if bodyless {
        self.body_writer.init_content_length(0);
    } else if !status.is_informational() || status == StatusCode::SWITCHING_PROTOCOLS {
        if self.is_upgrade(header) == Some(true) {
            self.body_writer.init_close_delimited();
        } else {
            init_body_writer_comm(&mut self.body_writer, &header.headers);
        }
    }
}
pub async fn write_response_header_ref(&mut self, resp: &ResponseHeader) -> Result<()> {
self.write_response_header(Box::new(resp.clone())).await
}
async fn do_write_body(&mut self, buf: &[u8]) -> Result<Option<usize>> {
let written = self
.body_writer
.write_body(&mut self.underlying_stream, buf)
.await;
if let Ok(Some(num_bytes)) = written {
self.body_bytes_sent += num_bytes;
}
written
}
pub async fn write_body(&mut self, buf: &[u8]) -> Result<Option<usize>> {
match self.write_timeout(buf.len()) {
Some(t) => match timeout(t, self.do_write_body(buf)).await {
Ok(res) => res,
Err(_) => Error::e_explain(WriteTimedout, format!("writing body, timeout: {t:?}")),
},
None => self.do_write_body(buf).await,
}
}
async fn do_write_body_buf(&mut self) -> Result<Option<usize>> {
if self.body_write_buf.is_empty() {
return Ok(None);
}
let written = self
.body_writer
.write_body(&mut self.underlying_stream, &self.body_write_buf)
.await;
if let Ok(Some(num_bytes)) = written {
self.body_bytes_sent += num_bytes;
}
self.body_write_buf.clear();
written
}
async fn write_body_buf(&mut self) -> Result<Option<usize>> {
match self.write_timeout(self.body_write_buf.len()) {
Some(t) => match timeout(t, self.do_write_body_buf()).await {
Ok(res) => res,
Err(_) => Error::e_explain(WriteTimedout, format!("writing body, timeout: {t:?}")),
},
None => self.do_write_body_buf().await,
}
}
/// On an upgraded session whose request body is still open, mark the body
/// reader finished (as a zero-length body).
fn maybe_force_close_body_reader(&mut self) {
    if !self.upgraded || self.body_reader.body_done() {
        return;
    }
    self.body_reader.init_content_length(0, b"");
}

/// Finish the response body and flush the stream.
///
/// Also force-closes the request body reader on upgraded sessions.
///
/// # Errors
/// Errors from the body writer, or `WriteError` if the flush fails.
pub async fn finish_body(&mut self) -> Result<Option<usize>> {
    let finished = self.body_writer.finish(&mut self.underlying_stream).await?;
    self.underlying_stream
        .flush()
        .await
        .or_err(WriteError, "flushing body")?;
    trace!(
        "finish body (response body writer), upgraded: {}",
        self.upgraded
    );
    self.maybe_force_close_body_reader();
    Ok(finished)
}

/// Total bytes written downstream (response header plus body).
pub fn body_bytes_sent(&self) -> usize {
    self.body_bytes_sent
}

/// Total request body bytes returned by `read_body_bytes`.
pub fn body_bytes_read(&self) -> usize {
    self.body_bytes_read
}
/// Whether the request's transfer encoding is chunked.
fn is_chunked_encoding(&self) -> bool {
    let headers = &self.req_header().headers;
    is_chunked_encoding_from_headers(headers)
}

/// The parsed `Content-Length` request header, if present.
fn get_content_length(&self) -> Result<Option<usize>> {
    let raw = self
        .get_header(header::CONTENT_LENGTH)
        .map(|v| v.as_bytes());
    buf_to_content_length(raw)
}

/// Set up the request body reader once, seeding it with the pre-read body bytes.
///
/// Upgraded sessions read until close. Chunked requests are read as chunked.
/// Anything else uses `Content-Length`; a missing or invalid value means no body.
/// Any retry buffer is cleared first.
fn init_body_reader(&mut self) {
    if !self.body_reader.need_init() {
        return;
    }
    if let Some(retry) = self.retry_buffer.as_mut() {
        retry.clear();
    }
    let preread_body = self.preread_body.as_ref().unwrap().get(&self.buf[..]);
    if self.was_upgraded() {
        self.body_reader.init_close_delimited(preread_body);
    } else if self.is_chunked_encoding() {
        self.body_reader.init_chunked(preread_body);
    } else {
        let length = self.get_content_length().unwrap_or(None).unwrap_or(0);
        self.body_reader.init_content_length(length, preread_body);
    }
}
/// Whether the retry buffer exists and overflowed its limit.
pub fn retry_buffer_truncated(&self) -> bool {
    self.retry_buffer
        .as_ref()
        .map_or(false, |retry| retry.is_truncated())
}

/// Start copying request body bytes into a bounded retry buffer (idempotent).
pub fn enable_retry_buffering(&mut self) {
    self.retry_buffer
        .get_or_insert_with(|| FixedBuffer::new(BODY_BUF_LIMIT));
}

/// The buffered request body, unless buffering is off or the buffer was truncated.
pub fn get_retry_buffer(&self) -> Option<Bytes> {
    let retry = self.retry_buffer.as_ref()?;
    if retry.is_truncated() {
        None
    } else {
        retry.get_buffer()
    }
}

/// Resolve a body `BufRef` returned by the body reader into bytes.
fn get_body(&self, buf_ref: &BufRef) -> &[u8] {
    let reader = &self.body_reader;
    reader.get_body(buf_ref)
}
pub async fn idle(&mut self) -> Result<usize> {
let mut buf: [u8; 1] = [0; 1];
self.underlying_stream
.read(&mut buf)
.await
.or_err(ReadError, "during HTTP idle state")
}
pub async fn read_body_or_idle(&mut self, no_body_expected: bool) -> Result<Option<Bytes>> {
if no_body_expected || self.is_body_done() {
let read = self.idle().await?;
if read == 0 {
Error::e_explain(
ConnectionClosed,
if self.response_written.is_none() {
"Prematurely before response header is sent"
} else {
"Prematurely before response body is complete"
},
)
} else {
Error::e_explain(ConnectError, "Sent data after end of body")
}
} else {
self.read_body_bytes().await
}
}
pub fn get_headers_raw_bytes(&self) -> Bytes {
self.raw_header.as_ref().unwrap().get_bytes(&self.buf)
}
pub async fn shutdown(&mut self) {
let _ = self.underlying_stream.shutdown().await;
}
/// Apply the server's keepalive setting, encoded like `set_keepalive`.
///
/// A client that sent a non-keepalive `Connection` header always gets
/// keepalive off.
pub fn set_server_keepalive(&mut self, keepalive: Option<u64>) {
    let effective = if self.is_connection_keepalive() == Some(false) {
        None
    } else {
        keepalive
    };
    self.set_keepalive(effective);
}

/// Set the per-read timeout for request header/body reads.
pub fn set_read_timeout(&mut self, timeout: Option<Duration>) {
    self.read_timeout = timeout;
}

/// The per-read timeout.
pub fn get_read_timeout(&self) -> Option<Duration> {
    self.read_timeout
}

/// Set the body write timeout (used when no minimum send rate is set).
pub fn set_write_timeout(&mut self, timeout: Option<Duration>) {
    self.write_timeout = timeout;
}

/// The body write timeout.
pub fn get_write_timeout(&self) -> Option<Duration> {
    self.write_timeout
}

/// Set the overall timeout for draining the request body.
pub fn set_total_drain_timeout(&mut self, timeout: Option<Duration>) {
    self.total_drain_timeout = timeout;
}

/// The overall drain timeout.
pub fn get_total_drain_timeout(&self) -> Option<Duration> {
    self.total_drain_timeout
}

/// Set the minimum send rate in bytes/sec; `None` or `Some(0)` disables it.
pub fn set_min_send_rate(&mut self, min_send_rate: Option<usize>) {
    self.min_send_rate = min_send_rate.filter(|r| *r > 0);
}

/// Whether to drop informational (1xx) responses (101 and a 100 for
/// `Expect: 100-continue` are always sent).
pub fn set_ignore_info_resp(&mut self, ignore: bool) {
    self.ignore_info_resp = ignore;
}

/// Whether to disable keepalive when responding before the request body is fully read.
pub fn set_close_on_response_before_downstream_finish(&mut self, close: bool) {
    self.close_on_response_before_downstream_finish = close;
}

/// Connection metadata captured when the session was created.
pub fn digest(&self) -> &Digest {
    &self.digest
}

/// Mutable access to the connection metadata.
pub fn digest_mut(&mut self) -> &mut Digest {
    &mut self.digest
}

/// The peer address from the socket digest, if known.
pub fn client_addr(&self) -> Option<&SocketAddr> {
    self.digest()
        .socket_digest
        .as_ref()
        .and_then(|socket| socket.peer_addr())
}

/// The local address from the socket digest, if known.
pub fn server_addr(&self) -> Option<&SocketAddr> {
    self.digest()
        .socket_digest
        .as_ref()
        .and_then(|socket| socket.local_addr())
}
/// Hand the connection back for reuse, or shut it down.
///
/// When keepalive applies, the remaining request body is drained first. The
/// stream is not returned if the body reader consumed bytes past the end of
/// the body (e.g. a pipelined request). Without keepalive the stream is shut
/// down and `Ok(None)` returned.
///
/// # Errors
/// Errors from draining the request body.
pub async fn reuse(mut self) -> Result<Option<Stream>> {
    if self.will_keepalive() {
        self.drain_request_body().await?;
        if self.body_reader.has_bytes_overread() {
            debug!("bytes overread on request, disallowing reuse");
            Ok(None)
        } else {
            Ok(Some(self.underlying_stream))
        }
    } else {
        debug!("HTTP shutdown connection");
        self.shutdown().await;
        Ok(None)
    }
}

/// Send `100 Continue` unless a response header has already been written.
pub async fn write_continue_response(&mut self) -> Result<()> {
    if self.response_written.is_some() {
        return Ok(());
    }
    let continue_header = ResponseHeader::build(100, Some(0)).unwrap();
    self.write_response_header(Box::new(continue_header)).await
}
/// Write a body chunk from a `Body` (`upgraded == false`) or `UpgradedBody`
/// (`upgraded == true`) task, skipping missing or empty data.
///
/// # Panics
/// If the task kind does not match the session's upgrade state.
async fn write_non_empty_body(&mut self, data: Option<Bytes>, upgraded: bool) -> Result<()> {
    match (upgraded, self.upgraded) {
        (true, false) => {
            panic!("Unexpected UpgradedBody task received on un-upgraded downstream session")
        }
        (false, true) => panic!("Unexpected Body task received on upgraded downstream session"),
        _ => {}
    }
    if let Some(chunk) = data.filter(|d| !d.is_empty()) {
        self.write_body(&chunk).await.map_err(|e| e.into_down())?;
    }
    Ok(())
}
/// Apply a single response task downstream.
///
/// Returns whether the response is now complete: either the task ended the
/// stream, or the body writer reports it has finished.
async fn response_duplex(&mut self, task: HttpTask) -> Result<bool> {
    let stream_ended = match task {
        HttpTask::Header(resp, ends) => {
            self.write_response_header(resp)
                .await
                .map_err(|e| e.into_down())?;
            ends
        }
        HttpTask::Body(chunk, ends) => {
            self.write_non_empty_body(chunk, false).await?;
            ends
        }
        HttpTask::UpgradedBody(chunk, ends) => {
            self.write_non_empty_body(chunk, true).await?;
            ends
        }
        // trailers are not forwarded; both end the response
        HttpTask::Trailer(_) | HttpTask::Done => true,
        HttpTask::Failed(err) => return Err(err),
    };
    if stream_ended {
        self.finish_body().await.map_err(|e| e.into_down())?;
    }
    Ok(stream_ended || self.body_writer.finished())
}
/// Append a body chunk to `body_write_buf` for a later batched write.
///
/// `upgraded` describes the task: `true` for an `UpgradedBody` task, `false`
/// for a `Body` task. Empty chunks, and data arriving after the body writer
/// has finished, are dropped.
///
/// # Panics
/// - On an `UpgradedBody` task when the session was not upgraded.
/// - On a `Body` task when the session was upgraded.
fn buffer_body_data(&mut self, data: Option<Bytes>, upgraded: bool) {
    if upgraded != self.upgraded {
        // Messages name the offending task kind and must agree with
        // `write_non_empty_body` (they were previously swapped here).
        if upgraded {
            panic!("Unexpected UpgradedBody task received on un-upgraded downstream session");
        } else {
            panic!("Unexpected Body task received on upgraded downstream session");
        }
    }
    let Some(d) = data else {
        return;
    };
    if !d.is_empty() && !self.body_writer.finished() {
        self.body_write_buf.put_slice(&d);
    }
}
/// Apply a batch of response tasks downstream.
///
/// Headers are written immediately. Body chunks are coalesced in
/// `body_write_buf` and written once at the end, or before returning a
/// `Failed` task's error. A single-task batch goes through `response_duplex`.
/// Returns whether the response is complete.
pub async fn response_duplex_vec(&mut self, mut tasks: Vec<HttpTask>) -> Result<bool> {
    if tasks.len() == 1 {
        return self.response_duplex(tasks.pop().unwrap()).await;
    }

    let mut stream_ended = false;
    for task in tasks {
        stream_ended = match task {
            HttpTask::Header(resp, ends) => {
                self.write_response_header(resp)
                    .await
                    .map_err(|err| err.into_down())?;
                ends
            }
            HttpTask::Body(chunk, ends) => {
                self.buffer_body_data(chunk, false);
                ends
            }
            HttpTask::UpgradedBody(chunk, ends) => {
                self.buffer_body_data(chunk, true);
                ends
            }
            HttpTask::Trailer(_) | HttpTask::Done => true,
            HttpTask::Failed(task_err) => {
                // push out what was already buffered before reporting the failure
                self.write_body_buf()
                    .await
                    .map_err(|write_err| write_err.into_down())?;
                self.underlying_stream
                    .flush()
                    .await
                    .or_err(WriteError, "flushing response")?;
                return Err(task_err);
            }
        };
    }

    self.write_body_buf().await.map_err(|err| err.into_down())?;
    if stream_ended {
        self.finish_body().await.map_err(|err| err.into_down())?;
    }
    Ok(stream_ended || self.body_writer.finished())
}
pub fn stream(&self) -> &Stream {
&self.underlying_stream
}
pub fn into_inner(self) -> Stream {
self.underlying_stream
}
}
/// Matches the start of an HTTP/1 request line (`METHOD <target> HTTP/x[.y]`),
/// capturing the request target as `uri`.
static REQUEST_LINE_REGEX: Lazy<Regex> = Lazy::new(|| {
    let pattern = r"^\w+ (?P<uri>.+) HTTP/\d(?:\.\d)?";
    Regex::new(pattern).unwrap()
});
/// Bytes percent-encoded when repairing an illegal request target:
/// ASCII control characters, space, `"`, `<` and `>`.
const URI_ESC_CHARSET: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>');
/// Percent-encode illegal bytes in the request target of `buf`'s request line.
///
/// Returns the rewritten buffer. Returns `None` when no request line is
/// recognized, or when escaping changes nothing (reparsing would fail the
/// same way).
fn escape_illegal_request_line(buf: &BytesMut) -> Option<BytesMut> {
    let captures = REQUEST_LINE_REGEX.captures(buf)?;
    let uri = captures.name("uri")?;

    let mut escaped = BytesMut::with_capacity(buf.len() + 32);
    escaped.extend_from_slice(&buf[..uri.start()]);
    percent_encode(uri.as_bytes(), URI_ESC_CHARSET)
        .for_each(|piece| escaped.extend_from_slice(piece.as_bytes()));

    // Same length as the original prefix + target means nothing was escaped.
    if escaped.len() == uri.end() {
        return None;
    }
    escaped.extend_from_slice(&buf[uri.end()..]);
    Some(escaped)
}
/// Parse `buf` as an HTTP/1 request header into `req`.
///
/// With the `patched_http1` feature, `parse_unchecked` is used instead of `parse`.
#[inline]
fn parse_req_buffer<'buf>(
    req: &mut httparse::Request<'_, 'buf>,
    buf: &'buf [u8],
) -> HeaderParseState {
    use httparse::Result;
    #[cfg(feature = "patched_http1")]
    fn parse<'buf>(req: &mut httparse::Request<'_, 'buf>, buf: &'buf [u8]) -> Result<usize> {
        req.parse_unchecked(buf)
    }
    #[cfg(not(feature = "patched_http1"))]
    fn parse<'buf>(req: &mut httparse::Request<'_, 'buf>, buf: &'buf [u8]) -> Result<usize> {
        req.parse(buf)
    }

    match parse(req, buf) {
        Ok(httparse::Status::Complete(header_len)) => HeaderParseState::Complete(header_len),
        Ok(_) => HeaderParseState::Partial,
        Err(e) => HeaderParseState::Invalid(e),
    }
}
/// Serialize `resp` as an HTTP/1 status line plus headers into `buf`.
///
/// Output is `HTTP/x.y <code> <reason>\r\n<headers>\r\n`.
/// Returns `Err(())` for versions other than HTTP/0.9, 1.0 and 1.1.
#[inline]
fn http_resp_header_to_buf(
    resp: &ResponseHeader,
    buf: &mut BytesMut,
) -> std::result::Result<(), ()> {
    let status_line_prefix: &[u8] = match resp.version {
        Version::HTTP_09 => b"HTTP/0.9 ",
        Version::HTTP_10 => b"HTTP/1.0 ",
        Version::HTTP_11 => b"HTTP/1.1 ",
        _ => return Err(()),
    };
    buf.put_slice(status_line_prefix);
    buf.put_slice(resp.status.as_str().as_bytes());
    buf.put_u8(b' ');
    if let Some(reason) = resp.get_reason_phrase() {
        buf.put_slice(reason.as_bytes());
    }
    buf.put_slice(CRLF);
    resp.header_to_h1_wire(buf);
    buf.put_slice(CRLF);
    Ok(())
}
#[cfg(test)]
mod tests_stream {
use super::*;
use crate::protocols::http::v1::body::{BodyMode, ParseState};
use http::StatusCode;
use pingora_error::ErrorType;
use rstest::rstest;
use std::str;
use tokio_test::io::Builder;
fn init_log() {
let _ = env_logger::builder().is_test(true).try_init();
}
#[tokio::test]
async fn read_basic() {
init_log();
let input = b"GET / HTTP/1.1\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_request().await;
assert_eq!(input.len(), res.unwrap().unwrap());
assert_eq!(0, http_stream.req_header().headers.len());
}
#[cfg(feature = "patched_http1")]
#[tokio::test]
async fn read_invalid_path() {
init_log();
let input = b"GET /\x01\xF0\x90\x80 HTTP/1.1\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_request().await;
assert_eq!(input.len(), res.unwrap().unwrap());
assert_eq!(0, http_stream.req_header().headers.len());
assert_eq!(b"/\x01\xF0\x90\x80", http_stream.get_path());
}
#[tokio::test]
async fn read_2_buf() {
init_log();
let input1 = b"GET / HTTP/1.1\r\n";
let input2 = b"Host: pingora.org\r\n\r\n";
let mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
let res = http_stream.read_request().await;
assert_eq!(input1.len() + input2.len(), res.unwrap().unwrap());
assert_eq!(
input1.len() + input2.len(),
http_stream.raw_header.as_ref().unwrap().len()
);
assert_eq!(1, http_stream.req_header().headers.len());
assert_eq!(Some(&Method::GET), http_stream.get_method());
assert_eq!(b"/", http_stream.get_path());
assert_eq!(Version::HTTP_11, http_stream.req_header().version);
assert_eq!(b"pingora.org", http_stream.get_header_bytes("Host"));
}
#[tokio::test]
async fn read_with_body_content_length() {
init_log();
let input1 = b"GET / HTTP/1.1\r\n";
let input2 = b"Host: pingora.org\r\nContent-Length: 3\r\n\r\n";
let input3 = b"abc";
let mock_io = Builder::new()
.read(&input1[..])
.read(&input2[..])
.read(&input3[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
let res = http_stream.read_body_bytes().await.unwrap().unwrap();
assert_eq!(res, input3.as_slice());
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(3));
assert_eq!(http_stream.body_bytes_read(), 3);
}
#[tokio::test]
#[should_panic(expected = "There is still data left to read.")]
async fn read_with_body_timeout() {
init_log();
let input1 = b"GET / HTTP/1.1\r\n";
let input2 = b"Host: pingora.org\r\nContent-Length: 3\r\n\r\n";
let input3 = b"abc";
let mock_io = Builder::new()
.read(&input1[..])
.read(&input2[..])
.wait(Duration::from_secs(2))
.read(&input3[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_timeout = Some(Duration::from_secs(1));
http_stream.read_request().await.unwrap();
let res = http_stream.read_body_bytes().await;
assert_eq!(http_stream.body_bytes_read(), 0);
assert_eq!(res.unwrap_err().etype(), &ReadTimedout);
}
#[tokio::test]
async fn read_with_body_content_length_single_read() {
init_log();
let input1 = b"GET / HTTP/1.1\r\n";
let input2 = b"Host: pingora.org\r\nContent-Length: 3\r\n\r\nabc";
let mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
let res = http_stream.read_body_bytes().await.unwrap().unwrap();
assert_eq!(res, b"abc".as_slice());
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(3));
assert_eq!(http_stream.body_bytes_read(), 3);
}
#[tokio::test]
#[should_panic(expected = "There is still data left to read.")]
async fn read_with_body_http10() {
init_log();
let input1 = b"GET / HTTP/1.0\r\n";
let input2 = b"Host: pingora.org\r\n\r\n";
let input3 = b"a"; let input4 = b""; let mock_io = Builder::new()
.read(&input1[..])
.read(&input2[..])
.read(&input3[..])
.read(&input4[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
let res = http_stream.read_body_bytes().await.unwrap();
assert!(res.is_none());
assert_eq!(http_stream.body_bytes_read(), 0);
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0));
}
#[tokio::test]
async fn read_with_body_http10_single_read() {
init_log();
let input1 = b"GET / HTTP/1.0\r\n";
let input2 = b"Host: pingora.org\r\n\r\na";
let mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
let res = http_stream.read_body_bytes().await.unwrap();
assert!(res.is_none());
assert_eq!(http_stream.body_bytes_read(), 0);
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0));
assert_eq!(http_stream.body_reader.get_body_overread().unwrap(), b"a");
}
#[tokio::test]
async fn read_http11_default_no_body() {
init_log();
let input1 = b"GET / HTTP/1.1\r\n";
let input2 = b"Host: pingora.org\r\n\r\n";
let mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
let res = http_stream.read_body_bytes().await.unwrap();
assert!(res.is_none());
assert_eq!(http_stream.body_bytes_read(), 0);
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(0));
}
#[tokio::test]
async fn read_http10_with_content_length() {
init_log();
let input1 = b"POST / HTTP/1.0\r\n";
let input2 = b"Host: pingora.org\r\nContent-Length: 3\r\n\r\n";
let input3 = b"abc";
let mock_io = Builder::new()
.read(&input1[..])
.read(&input2[..])
.read(&input3[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
let res = http_stream.read_body_bytes().await.unwrap().unwrap();
assert_eq!(res, input3.as_slice());
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(3));
assert_eq!(http_stream.body_bytes_read(), 3);
}
#[tokio::test]
async fn read_with_body_chunked_0_incomplete() {
init_log();
let input1 = b"GET / HTTP/1.1\r\n";
let input2 = b"Host: pingora.org\r\nTransfer-Encoding: chunked\r\n\r\n";
let input3 = b"0\r\n";
let mock_io = Builder::new()
.read(&input1[..])
.read(&input2[..])
.read(&input3[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
assert!(http_stream.is_chunked_encoding());
let res = http_stream.read_body_bytes().await.unwrap().unwrap();
assert_eq!(res, b"".as_slice());
let e = http_stream.read_body_bytes().await.unwrap_err();
assert_eq!(*e.etype(), ErrorType::ConnectionClosed);
assert_eq!(http_stream.body_reader.body_state, ParseState::Done(0));
}
// Like the incomplete case, but stray bytes `abc` arrive after `0\r\n`
// instead of the closing CRLF. Two body reads yield empty slices, then the
// third fails with ConnectionClosed and the reader sits in Done(0).
#[tokio::test]
async fn read_with_body_chunked_0_extra() {
    init_log();
    let request_line = b"GET / HTTP/1.1\r\n";
    let headers = b"Host: pingora.org\r\nTransfer-Encoding: chunked\r\n\r\n";
    let last_chunk = b"0\r\n";
    let trailing = b"abc";
    let io = Builder::new()
        .read(&request_line[..])
        .read(&headers[..])
        .read(&last_chunk[..])
        .read(&trailing[..])
        .build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    assert!(session.is_chunked_encoding());

    for _ in 0..2 {
        let chunk = session.read_body_bytes().await.unwrap().unwrap();
        assert_eq!(chunk, b"".as_slice());
    }

    let err = session.read_body_bytes().await.unwrap_err();
    assert_eq!(*err.etype(), ErrorType::ConnectionClosed);
    assert_eq!(session.body_reader.body_state, ParseState::Done(0));
}
// The single data chunk `a` arrives in the same read as the headers; the
// terminating `0\r\n\r\n` comes later. The first body read returns `a` while
// still in the chunked state. The second returns None and completes with one
// byte read.
#[tokio::test]
async fn read_with_body_chunked_single_read() {
    init_log();
    let request_line = b"GET / HTTP/1.1\r\n";
    let headers_and_chunk =
        b"Host: pingora.org\r\nTransfer-Encoding: chunked\r\n\r\n1\r\na\r\n";
    let terminator = b"0\r\n\r\n";
    let io = Builder::new()
        .read(&request_line[..])
        .read(&headers_and_chunk[..])
        .read(&terminator[..])
        .build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    assert!(session.is_chunked_encoding());

    let data = session.read_body_bytes().await.unwrap().unwrap();
    assert_eq!(data, b"a".as_slice());
    assert_eq!(
        session.body_reader.body_state,
        ParseState::Chunked(1, 0, 0, 0)
    );

    assert!(session.read_body_bytes().await.unwrap().is_none());
    assert_eq!(session.body_bytes_read(), 1);
    assert_eq!(session.body_reader.body_state, ParseState::Complete(1));
}
// Same as the single-read case, but `abc` follows the terminating chunk in the
// same read. The body still completes with one byte, and the trailing bytes
// are kept as overread.
#[tokio::test]
async fn read_with_body_chunked_single_read_extra() {
    init_log();
    let request_line = b"GET / HTTP/1.1\r\n";
    let headers_and_chunk =
        b"Host: pingora.org\r\nTransfer-Encoding: chunked\r\n\r\n1\r\na\r\n";
    let terminator_and_extra = b"0\r\n\r\nabc";
    let io = Builder::new()
        .read(&request_line[..])
        .read(&headers_and_chunk[..])
        .read(&terminator_and_extra[..])
        .build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    assert!(session.is_chunked_encoding());

    let data = session.read_body_bytes().await.unwrap().unwrap();
    assert_eq!(data, b"a".as_slice());
    assert_eq!(
        session.body_reader.body_state,
        ParseState::Chunked(1, 0, 0, 0)
    );

    assert!(session.read_body_bytes().await.unwrap().is_none());
    assert_eq!(session.body_bytes_read(), 1);
    assert_eq!(session.body_reader.body_state, ParseState::Complete(1));
    let overread = session.body_reader.get_body_overread().unwrap();
    assert_eq!(overread, b"abc");
}
// Transfer-Encoding and Content-Length must never both survive request
// parsing:
// - when Transfer-Encoding is present, Content-Length is removed;
// - a lone Content-Length is kept;
// - with neither, neither is present.
// Header names are sent in every case combination, including all-lowercase
// for both (the last case).
#[rstest]
#[case(None, None)]
#[case(Some("transfer-encoding"), None)]
#[case(Some("transfer-encoding"), Some("CONTENT-LENGTH"))]
#[case(Some("TRANSFER-ENCODING"), Some("CONTENT-LENGTH"))]
#[case(Some("TRANSFER-ENCODING"), None)]
#[case(None, Some("CONTENT-LENGTH"))]
#[case(Some("TRANSFER-ENCODING"), Some("content-length"))]
#[case(None, Some("content-length"))]
#[case(Some("transfer-encoding"), Some("content-length"))]
#[tokio::test]
async fn transfer_encoding_and_content_length_disallowed(
    #[case] transfer_encoding_header: Option<&str>,
    #[case] content_length_header: Option<&str>,
) {
    init_log();
    let input1 = b"GET / HTTP/1.1\r\n";
    let mut input2 = "Host: pingora.org\r\n".to_owned();
    if let Some(transfer_encoding) = transfer_encoding_header {
        input2 += &format!("{transfer_encoding}: chunked\r\n");
    }
    if let Some(content_length) = content_length_header {
        input2 += &format!("{content_length}: 4\r\n");
    }
    input2 += "\r\n3e\r\na\r\n";
    let mock_io = Builder::new()
        .read(&input1[..])
        .read(input2.as_bytes())
        .build();
    let mut http_stream = HttpSession::new(Box::new(mock_io));
    http_stream.read_request().await.unwrap();

    // Tuple order matches the #[case] parameter order: (TE, CL).
    match (transfer_encoding_header, content_length_header) {
        // TE always wins; any CL alongside it must be stripped.
        (Some(_), _) => {
            assert!(http_stream.get_header(TRANSFER_ENCODING).is_some());
            assert!(http_stream.get_header(CONTENT_LENGTH).is_none());
        }
        (None, Some(_)) => {
            assert!(http_stream.get_header(TRANSFER_ENCODING).is_none());
            assert!(http_stream.get_header(CONTENT_LENGTH).is_some());
        }
        (None, None) => {
            assert!(http_stream.get_header(CONTENT_LENGTH).is_none());
            assert!(http_stream.get_header(TRANSFER_ENCODING).is_none());
        }
    }
}
// Each malformed Content-Length value below (signed, non-numeric, fractional,
// empty, blank, trailing garbage) must make `read_request` fail with
// InvalidHTTPHeader.
#[rstest]
#[case::negative("-1")]
#[case::not_a_number("abc")]
#[case::float("1.5")]
#[case::empty("")]
#[case::spaces(" ")]
#[case::mixed("123abc")]
#[tokio::test]
async fn validate_request_rejects_invalid_content_length(#[case] invalid_value: &str) {
    init_log();
    let raw = format!(
        "POST / HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: {invalid_value}\r\n\r\n"
    );
    let io = Builder::new().read(raw.as_bytes()).build();
    let mut session = HttpSession::new(Box::new(io));
    let outcome = session.read_request().await;
    assert!(outcome.is_err());
    let err = outcome.unwrap_err();
    assert_eq!(err.etype(), &InvalidHTTPHeader);
}
// Well-formed decimal Content-Length values, including zero, are accepted.
#[rstest]
#[case::valid_zero("0")]
#[case::valid_small("123")]
#[case::valid_large("999999")]
#[tokio::test]
async fn validate_request_accepts_valid_content_length(#[case] valid_value: &str) {
    init_log();
    let raw = format!(
        "POST / HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: {valid_value}\r\n\r\n"
    );
    let io = Builder::new().read(raw.as_bytes()).build();
    let mut session = HttpSession::new(Box::new(io));
    assert!(session.read_request().await.is_ok());
}
// A request with no Content-Length at all passes header validation.
#[tokio::test]
async fn validate_request_accepts_no_content_length() {
    init_log();
    let raw = b"GET / HTTP/1.1\r\nHost: pingora.org\r\n\r\n";
    let io = Builder::new().read(&raw[..]).build();
    let mut session = HttpSession::new(Box::new(io));
    assert!(session.read_request().await.is_ok());
}
// A malformed protocol token (`HTP/1.1`) is rejected with InvalidHTTPHeader.
// The second scripted read is apparently never consumed, so the mock panics
// with "There is still data left to read." when dropped; the `should_panic`
// expectation relies on that.
#[tokio::test]
#[should_panic(expected = "There is still data left to read.")]
async fn read_invalid() {
    let bad_request_line = b"GET / HTP/1.1\r\n";
    let headers = b"Host: pingora.org\r\n\r\n";
    let io = Builder::new()
        .read(&bad_request_line[..])
        .read(&headers[..])
        .build();
    let mut session = HttpSession::new(Box::new(io));
    let outcome = session.read_request().await;
    let err = outcome.unwrap_err();
    assert_eq!(&InvalidHTTPHeader, err.etype());
}
// A header line terminated by a stray CR before CRLF (`\r\r\n`) must be
// rejected with InvalidHTTPHeader.
#[tokio::test]
async fn read_invalid_header_end() {
    let raw = b"POST / HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: 3\r\r\nConnection: keep-alive\r\n\r\nabc";
    let io = Builder::new().read(&raw[..]).build();
    let mut session = HttpSession::new(Box::new(io));
    let err = session.read_request().await.unwrap_err();
    assert_eq!(&InvalidHTTPHeader, err.etype());
}
// Parses an HTTP/1.1 GET carrying the given `Upgrade` and `Connection`
// header values. Returns the session with the request header already read.
async fn build_upgrade_req(upgrade: &str, conn: &str) -> HttpSession {
    let raw = format!(
        "GET / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: {upgrade}\r\nConnection: {conn}\r\n\r\n"
    );
    let io = Builder::new().read(raw.as_bytes()).build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    session
}
#[tokio::test]
async fn read_upgrade_req() {
let input = b"GET / HTTP/1.0\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
assert!(!http_stream.is_upgrade_req());
let input = b"POST / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
assert!(http_stream.is_upgrade_req());
let input = b"GET / HTTP/1.1\r\nHost: pingora.org\r\nConnection: upgrade\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
assert!(!http_stream.is_upgrade_req());
let input = b"GET / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: WebSocket\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
assert!(http_stream.is_upgrade_req());
assert!(build_upgrade_req("websocket", "Upgrade")
.await
.is_upgrade_req());
assert!(build_upgrade_req("WebSocket", "Upgrade")
.await
.is_upgrade_req());
}
/// Upgrade request header whose 10-byte body is framed by Content-Length.
const POST_CL_UPGRADE_REQ: &[u8] = b"POST / HTTP/1.1\r\n\
    Host: pingora.org\r\n\
    Upgrade: websocket\r\n\
    Connection: upgrade\r\n\
    Content-Length: 10\r\n\r\n";
/// The decoded 10-byte request body shared by the upgrade-with-body tests.
const POST_BODY_DATA: &[u8] = b"abcdefghij";
/// Upgrade request header whose body uses chunked transfer coding.
const POST_CHUNKED_UPGRADE_REQ: &[u8] = b"POST / HTTP/1.1\r\n\
    Host: pingora.org\r\n\
    Upgrade: websocket\r\n\
    Connection: upgrade\r\n\
    Transfer-Encoding: chunked\r\n\r\n";
/// `POST_BODY_DATA` as two chunks (3 + 7 bytes) plus the last-chunk marker.
const POST_BODY_DATA_CHUNKED: &[u8] = b"3\r\nabc\r\n\
    7\r\ndefghij\r\n\
    0\r\n\r\n";
#[rstest]
#[case::content_length(POST_CL_UPGRADE_REQ, POST_BODY_DATA, POST_BODY_DATA)]
#[case::chunked(POST_CHUNKED_UPGRADE_REQ, POST_BODY_DATA, POST_BODY_DATA_CHUNKED)]
#[tokio::test]
async fn read_upgrade_req_with_body(
#[case] header: &[u8],
#[case] body: &[u8],
#[case] body_wire: &[u8],
) {
let ws_data = b"data";
let mock_io = Builder::new()
.read(header)
.read(body_wire)
.write(b"HTTP/1.1 101 Switching Protocols\r\n\r\n")
.read(&ws_data[..])
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
assert!(http_stream.is_upgrade_req());
assert!(!http_stream.is_body_done());
let mut buf = vec![];
while let Some(b) = http_stream.read_body_bytes().await.unwrap() {
buf.put_slice(&b);
}
assert_eq!(buf, body);
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(10));
assert_eq!(http_stream.body_bytes_read(), 10);
assert!(http_stream.is_body_done());
let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap();
response.set_version(http::Version::HTTP_11);
http_stream
.write_response_header(Box::new(response))
.await
.unwrap();
assert!(!http_stream.is_body_done());
let buf = http_stream.read_body_bytes().await.unwrap().unwrap();
assert_eq!(buf, ws_data.as_slice());
assert!(!http_stream.is_body_done());
assert!(http_stream.read_body_bytes().await.unwrap().is_none());
assert!(http_stream.is_body_done());
}
#[rstest]
#[case::content_length(POST_CL_UPGRADE_REQ, POST_BODY_DATA, POST_BODY_DATA)]
#[case::chunked(POST_CHUNKED_UPGRADE_REQ, POST_BODY_DATA, POST_BODY_DATA_CHUNKED)]
#[tokio::test]
async fn read_upgrade_req_with_body_extra(
#[case] header: &[u8],
#[case] body: &[u8],
#[case] body_wire: &[u8],
) {
let ws_data = b"data";
let data_wire = [body_wire, ws_data.as_slice()].concat();
let mock_io = Builder::new()
.read(header)
.read(&data_wire[..])
.write(b"HTTP/1.1 101 Switching Protocols\r\n\r\n")
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
assert!(http_stream.is_upgrade_req());
assert!(!http_stream.is_body_done());
let mut buf = vec![];
while let Some(b) = http_stream.read_body_bytes().await.unwrap() {
buf.put_slice(&b);
}
assert_eq!(buf, body);
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(10));
assert_eq!(http_stream.body_bytes_read(), 10);
assert!(http_stream.is_body_done());
let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap();
response.set_version(http::Version::HTTP_11);
http_stream
.write_response_header(Box::new(response))
.await
.unwrap();
assert!(!http_stream.is_body_done());
let buf = http_stream.read_body_bytes().await.unwrap().unwrap();
assert_eq!(buf, ws_data.as_slice());
assert!(!http_stream.is_body_done());
assert!(http_stream.read_body_bytes().await.unwrap().is_none());
assert!(http_stream.is_body_done());
}
#[rstest]
#[case::content_length(POST_CL_UPGRADE_REQ, POST_BODY_DATA, POST_BODY_DATA)]
#[case::chunked(POST_CHUNKED_UPGRADE_REQ, POST_BODY_DATA, POST_BODY_DATA_CHUNKED)]
#[tokio::test]
async fn read_upgrade_req_with_preread_body(
#[case] header: &[u8],
#[case] body: &[u8],
#[case] body_wire: &[u8],
) {
let ws_data = b"data";
let data_wire = [header, body_wire, ws_data.as_slice()].concat();
let mock_io = Builder::new()
.read(&data_wire[..])
.write(b"HTTP/1.1 101 Switching Protocols\r\n\r\n")
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
assert!(http_stream.is_upgrade_req());
assert!(!http_stream.is_body_done());
let mut buf = vec![];
while let Some(b) = http_stream.read_body_bytes().await.unwrap() {
buf.put_slice(&b);
}
assert_eq!(buf, body);
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(10));
assert_eq!(http_stream.body_bytes_read(), 10);
assert!(http_stream.is_body_done());
let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap();
response.set_version(http::Version::HTTP_11);
http_stream
.write_response_header(Box::new(response))
.await
.unwrap();
assert!(!http_stream.is_body_done());
let buf = http_stream.read_body_bytes().await.unwrap().unwrap();
assert_eq!(buf, ws_data.as_slice());
assert!(!http_stream.is_body_done());
assert!(http_stream.read_body_bytes().await.unwrap().is_none());
assert!(http_stream.is_body_done());
}
// When the 101 is sent before the request body has been drained, everything
// after the request header is returned verbatim. That means the still-framed
// body bytes (chunk markers included) followed by the upgraded-protocol data.
#[rstest]
#[case::content_length(POST_CL_UPGRADE_REQ, POST_BODY_DATA)]
#[case::chunked(POST_CHUNKED_UPGRADE_REQ, POST_BODY_DATA_CHUNKED)]
#[tokio::test]
async fn read_upgrade_req_with_preread_body_after_101(
    #[case] header: &[u8],
    #[case] body_wire: &[u8],
) {
    let ws_data = b"data";
    let wire = [header, body_wire, ws_data.as_slice()].concat();
    let io = Builder::new()
        .read(&wire[..])
        .write(b"HTTP/1.1 101 Switching Protocols\r\n\r\n")
        .build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    assert!(session.is_upgrade_req());
    assert!(!session.is_body_done());

    let mut switching = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap();
    switching.set_version(http::Version::HTTP_11);
    session
        .write_response_header(Box::new(switching))
        .await
        .unwrap();
    assert!(!session.is_body_done());

    let mut received = Vec::new();
    while let Some(piece) = session.read_body_bytes().await.unwrap() {
        received.put_slice(&piece);
    }
    let expected = [body_wire, ws_data.as_slice()].concat();
    assert_eq!(received, expected);
    assert_eq!(session.body_bytes_read(), expected.len());
    assert!(session.is_body_done());
}
// Sending 100 Continue to an upgrade request does not switch protocols: the
// (empty) request body counts as done. Only the following 101 re-opens the
// body reader, which then reports EOF.
#[tokio::test]
async fn read_upgrade_req_with_1xx_response() {
    let request = b"GET / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\n\r\n";
    let io = Builder::new()
        .read(&request[..])
        .write(b"HTTP/1.1 100 Continue\r\n\r\n")
        .write(b"HTTP/1.1 101 Switching Protocols\r\n\r\n")
        .build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    assert!(session.is_upgrade_req());

    let mut interim = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap();
    interim.set_version(http::Version::HTTP_11);
    session
        .write_response_header(Box::new(interim))
        .await
        .unwrap();
    assert!(session.is_body_done());

    let mut switching = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap();
    switching.set_version(http::Version::HTTP_11);
    session
        .write_response_header(Box::new(switching))
        .await
        .unwrap();
    assert!(!session.is_body_done());

    assert!(session.read_body_bytes().await.unwrap().is_none());
    assert!(session.is_body_done());
}
// An upgrade request without any body framing: after the 101 the reader must
// switch to UntilClose(0) and hand back the websocket bytes that follow.
#[tokio::test]
async fn test_upgrade_without_content_length_with_ws_data() {
    let request = b"GET / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\n\r\n";
    let ws_data = b"websocket data";
    let io = Builder::new()
        .read(request)
        .write(b"HTTP/1.1 101 Switching Protocols\r\n\r\n")
        .read(ws_data)
        .build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    assert!(session.is_upgrade_req());
    session.set_close_on_response_before_downstream_finish(false);

    let mut switching = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap();
    switching.set_version(http::Version::HTTP_11);
    session
        .write_response_header(Box::new(switching))
        .await
        .unwrap();
    assert_eq!(
        session.body_reader.body_state,
        ParseState::UntilClose(0),
        "Body reader should be in UntilClose mode after 101 for upgraded connections"
    );

    let mut received = Vec::new();
    while let Some(piece) = session.read_body_bytes().await.unwrap() {
        received.put_slice(&piece);
    }
    assert_eq!(received, ws_data, "Expected to read websocket data after 101");
}
#[tokio::test]
async fn set_server_keepalive() {
let input = b"GET / HTTP/1.1\r\nHost: pingora.org\r\nConnection: close\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
assert_eq!(http_stream.keepalive_timeout, KeepaliveStatus::Off);
http_stream.set_server_keepalive(Some(60));
assert_eq!(http_stream.keepalive_timeout, KeepaliveStatus::Off);
let input = b"GET / HTTP/1.1\r\nHost: pingora.org\r\nConnection: keep-alive\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
assert_eq!(http_stream.keepalive_timeout, KeepaliveStatus::Infinite);
http_stream.set_server_keepalive(Some(60));
assert_eq!(
http_stream.keepalive_timeout,
KeepaliveStatus::Timeout(Duration::from_secs(60))
);
let input = b"GET / HTTP/1.1\r\nHost: pingora.org\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.read_request().await.unwrap();
assert_eq!(http_stream.keepalive_timeout, KeepaliveStatus::Infinite);
http_stream.set_server_keepalive(Some(60));
assert_eq!(
http_stream.keepalive_timeout,
KeepaliveStatus::Timeout(Duration::from_secs(60))
);
}
// With `update_resp_headers` disabled, the response header reaches the wire
// exactly as built.
#[tokio::test]
async fn write() {
    let request = b"GET / HTTP/1.1\r\n\r\n";
    let expected_wire = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n";
    let io = Builder::new().read(request).write(expected_wire).build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();

    let mut ok = ResponseHeader::build(StatusCode::OK, None).unwrap();
    ok.append_header("Foo", "Bar").unwrap();
    session.update_resp_headers = false;
    session.write_response_header_ref(&ok).await.unwrap();
}
// A custom reason phrase replaces the canonical one on the status line.
#[tokio::test]
async fn write_custom_reason() {
    let request = b"GET / HTTP/1.1\r\n\r\n";
    let expected_wire = b"HTTP/1.1 200 Just Fine\r\nFoo: Bar\r\n\r\n";
    let io = Builder::new().read(request).write(expected_wire).build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();

    let mut ok = ResponseHeader::build(StatusCode::OK, None).unwrap();
    ok.set_reason_phrase(Some("Just Fine")).unwrap();
    ok.append_header("Foo", "Bar").unwrap();
    session.update_resp_headers = false;
    session.write_response_header_ref(&ok).await.unwrap();
}
// A 100 Continue is written as-is and does not prevent the final 200 from
// following it on the same connection.
#[tokio::test]
async fn write_informational() {
    let request = b"GET / HTTP/1.1\r\n\r\n";
    let expected_wire = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n";
    let io = Builder::new().read(request).write(expected_wire).build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();

    let interim = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap();
    session.write_response_header_ref(&interim).await.unwrap();

    let mut final_resp = ResponseHeader::build(StatusCode::OK, None).unwrap();
    final_resp.append_header("Foo", "Bar").unwrap();
    // Header rewriting is only disabled for the final response.
    session.update_resp_headers = false;
    session.write_response_header_ref(&final_resp).await.unwrap();
}
// With `ignore_info_resp` set, the 100 Continue is swallowed and only the 200
// reaches the wire.
#[tokio::test]
async fn write_informational_ignored() {
    let request = b"GET / HTTP/1.1\r\n\r\n";
    let expected_wire = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n";
    let io = Builder::new().read(request).write(expected_wire).build();
    let mut session = HttpSession::new(Box::new(io));
    session.ignore_info_resp = true;
    session.read_request().await.unwrap();

    let interim = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap();
    session.write_response_header_ref(&interim).await.unwrap();

    let mut final_resp = ResponseHeader::build(StatusCode::OK, None).unwrap();
    final_resp.append_header("Foo", "Bar").unwrap();
    session.update_resp_headers = false;
    session.write_response_header_ref(&final_resp).await.unwrap();
}
// When the client sent `Expect: 100-continue`, the 100 Continue is still
// forwarded even though `ignore_info_resp` is set.
#[tokio::test]
async fn write_informational_100_not_ignored_if_expect_continue() {
    let request = b"GET / HTTP/1.1\r\nExpect: 100-continue\r\n\r\n";
    let expected_wire = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n";
    let io = Builder::new().read(&request[..]).write(expected_wire).build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    session.ignore_info_resp = true;

    let interim = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap();
    session.write_response_header_ref(&interim).await.unwrap();

    let mut final_resp = ResponseHeader::build(StatusCode::OK, None).unwrap();
    final_resp.append_header("Foo", "Bar").unwrap();
    session.update_resp_headers = false;
    session.write_response_header_ref(&final_resp).await.unwrap();
}
// `Expect: 100-continue` only exempts 100 itself: another 1xx (102
// Processing) is still dropped when `ignore_info_resp` is set.
#[tokio::test]
async fn write_informational_1xx_ignored_if_expect_continue() {
    let request = b"GET / HTTP/1.1\r\nExpect: 100-continue\r\n\r\n";
    let expected_wire = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n";
    let io = Builder::new().read(&request[..]).write(expected_wire).build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    session.ignore_info_resp = true;

    let processing = ResponseHeader::build(StatusCode::PROCESSING, None).unwrap();
    session.write_response_header_ref(&processing).await.unwrap();

    let mut final_resp = ResponseHeader::build(StatusCode::OK, None).unwrap();
    final_resp.append_header("Foo", "Bar").unwrap();
    session.update_resp_headers = false;
    session.write_response_header_ref(&final_resp).await.unwrap();
}
// After a 101, the body writer switches to UntilClose so upgraded bytes pass
// straight through. A later 502 header has no matching scripted write.
// NOTE(review): this presumably checks that the 502 is never emitted after
// the upgrade, relying on the mock rejecting unexpected writes — confirm.
#[tokio::test]
async fn write_101_switching_protocol() {
    let request = b"GET / HTTP/1.1\r\nUpgrade: websocket\r\n\r\n";
    let header_wire = b"HTTP/1.1 101 Switching Protocols\r\nFoo: Bar\r\n\r\n";
    let payload = b"nPAYLOAD";
    let io = Builder::new()
        .read(request)
        .write(header_wire)
        .write(payload)
        .build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();

    let mut switching = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap();
    switching.append_header("Foo", "Bar").unwrap();
    session.write_response_header_ref(&switching).await.unwrap();
    assert_eq!(session.body_writer.body_mode, BodyMode::UntilClose(0));

    let written = session.write_body(payload).await.unwrap().unwrap();
    assert_eq!(payload.len(), written);
    assert_eq!(session.body_writer.body_mode, BodyMode::UntilClose(written));

    let bad_gateway = ResponseHeader::build(StatusCode::BAD_GATEWAY, None).unwrap();
    session.write_response_header_ref(&bad_gateway).await.unwrap();
}
// A response with `Content-Length: 1` puts the writer in ContentLength(1, 0).
// Writing and finishing the body both report the single byte sent.
#[tokio::test]
async fn write_body_cl() {
    let request = b"GET / HTTP/1.1\r\n\r\n";
    let header_wire = b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\n";
    let payload = b"a";
    let io = Builder::new()
        .read(request)
        .write(header_wire)
        .write(payload)
        .build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();

    let mut ok = ResponseHeader::build(StatusCode::OK, None).unwrap();
    ok.append_header("Content-Length", "1").unwrap();
    session.update_resp_headers = false;
    session.write_response_header_ref(&ok).await.unwrap();
    assert_eq!(session.body_writer.body_mode, BodyMode::ContentLength(1, 0));

    let written = session.write_body(payload).await.unwrap().unwrap();
    assert_eq!(payload.len(), written);
    let total = session.finish_body().await.unwrap().unwrap();
    assert_eq!(payload.len(), total);
}
// Despite the name, both request and response are HTTP/1.1. The response has
// neither Content-Length nor Transfer-Encoding, so the body is close-delimited
// (HTTP/1.0-style framing) and the writer is in UntilClose mode.
#[tokio::test]
async fn write_body_http10() {
    let request = b"GET / HTTP/1.1\r\n\r\n";
    let header_wire = b"HTTP/1.1 200 OK\r\n\r\n";
    let payload = b"a";
    let io = Builder::new()
        .read(request)
        .write(header_wire)
        .write(payload)
        .build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();

    let ok = ResponseHeader::build(StatusCode::OK, None).unwrap();
    session.update_resp_headers = false;
    session.write_response_header_ref(&ok).await.unwrap();
    assert_eq!(session.body_writer.body_mode, BodyMode::UntilClose(0));

    let written = session.write_body(payload).await.unwrap().unwrap();
    assert_eq!(payload.len(), written);
    let total = session.finish_body().await.unwrap().unwrap();
    assert_eq!(payload.len(), total);
}
// With `Transfer-Encoding: chunked`, each body write is wrapped as a chunk,
// and finishing the body emits the terminating `0\r\n\r\n`.
#[tokio::test]
async fn write_body_chunk() {
    let request = b"GET / HTTP/1.1\r\n\r\n";
    let header_wire = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n";
    let chunk_wire = b"1\r\na\r\n";
    let end_wire = b"0\r\n\r\n";
    let io = Builder::new()
        .read(request)
        .write(header_wire)
        .write(chunk_wire)
        .write(end_wire)
        .build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();

    let mut ok = ResponseHeader::build(StatusCode::OK, None).unwrap();
    ok.append_header("Transfer-Encoding", "chunked").unwrap();
    session.update_resp_headers = false;
    session.write_response_header_ref(&ok).await.unwrap();
    assert_eq!(session.body_writer.body_mode, BodyMode::ChunkedEncoding(0));

    let payload = b"a";
    let written = session.write_body(payload).await.unwrap().unwrap();
    assert_eq!(payload.len(), written);
    let total = session.finish_body().await.unwrap().unwrap();
    assert_eq!(payload.len(), total);
}
// A raw space inside the request target (`/a?q=b c`) is tolerated: the path
// is percent-escaped to `%20`, and the Content-Length body is still read.
#[tokio::test]
async fn read_with_illegal() {
    init_log();
    let request_line = b"GET /a?q=b c HTTP/1.1\r\n";
    let headers = b"Host: pingora.org\r\nContent-Length: 3\r\n\r\n";
    let payload = b"abc";
    let io = Builder::new()
        .read(&request_line[..])
        .read(&headers[..])
        .read(&payload[..])
        .build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    assert_eq!(session.get_path(), &b"/a?q=b%20c"[..]);

    let body_ref = session.read_body().await.unwrap().unwrap();
    assert_eq!(body_ref, BufRef::new(0, 3));
    assert_eq!(session.body_reader.body_state, ParseState::Complete(3));
    assert_eq!(payload, session.get_body(&body_ref));
}
// `escape_illegal_request_line` percent-encodes illegal characters in the
// request target and leaves the rest of the head untouched. A request line
// with no target cannot be repaired, and the function returns None.
#[test]
fn escape_illegal() {
    init_log();
    // (raw request head, expected escaped head)
    let cases: [(&[u8], &[u8]); 2] = [
        (
            &b"GET /a?q=<\"b c\"> HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: 3\r\n\r\n"[..],
            &b"GET /a?q=%3C%22b%20c%22%3E HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: 3\r\n\r\n"[..],
        ),
        (
            &b"GET /a:\"bc\" HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: 3\r\n\r\n"[..],
            &b"GET /a:%22bc%22 HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: 3\r\n\r\n"[..],
        ),
    ];
    for &(raw, escaped) in cases.iter() {
        let input = BytesMut::from(raw);
        let output = escape_illegal_request_line(&input).unwrap();
        assert_eq!(&output, escaped);
    }

    let missing_target =
        BytesMut::from(&b"GET HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: 3\r\n\r\n"[..]);
    assert!(escape_illegal_request_line(&missing_target).is_none());
}
// Flushing an empty body write buffer writes nothing and returns None.
#[tokio::test]
async fn test_write_body_buf() {
    let request = b"GET / HTTP/1.1\r\n\r\n";
    let expected_wire = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n";
    let io = Builder::new().read(request).write(expected_wire).build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();

    let mut ok = ResponseHeader::build(StatusCode::OK, None).unwrap();
    ok.append_header("Foo", "Bar").unwrap();
    session.update_resp_headers = false;
    session.write_response_header_ref(&ok).await.unwrap();

    assert!(session.write_body_buf().await.unwrap().is_none());
}
// The mock stalls 500 ms before accepting the body, while the write timeout
// is 100 ms, so `write_body_buf` fails with WriteTimedout. The body write is
// left unconsumed, which makes the mock panic on drop with the message the
// `should_panic` attribute expects.
#[tokio::test]
#[should_panic(expected = "There is still data left to write.")]
async fn test_write_body_buf_write_timeout() {
    let request = b"GET / HTTP/1.1\r\n\r\n";
    let header_wire = b"HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\n";
    let body_wire = b"abc";
    let io = Builder::new()
        .read(request)
        .write(header_wire)
        .wait(Duration::from_millis(500))
        .write(body_wire)
        .build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    session.write_timeout = Some(Duration::from_millis(100));

    let mut ok = ResponseHeader::build(StatusCode::OK, None).unwrap();
    ok.append_header("Content-Length", "3").unwrap();
    session.update_resp_headers = false;
    session.write_response_header_ref(&ok).await.unwrap();

    session.body_write_buf = BytesMut::from(&b"abc"[..]);
    let err = session.write_body_buf().await.unwrap_err();
    assert_eq!(err.etype(), &WriteTimedout);
}
// `write_continue_response` emits a bare `100 Continue` status line.
#[tokio::test]
async fn test_write_continue_resp() {
    let request = b"GET / HTTP/1.1\r\n\r\n";
    let expected_wire = b"HTTP/1.1 100 Continue\r\n\r\n";
    let io = Builder::new().read(request).write(expected_wire).build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    session.write_continue_response().await.unwrap();
}
// Without a min send rate, the configured write timeout is returned unchanged.
#[test]
fn test_get_write_timeout() {
    let configured = Duration::from_secs(5);
    let mut session = HttpSession::new(Box::new(Builder::new().build()));
    session.set_write_timeout(Some(configured));
    assert_eq!(session.write_timeout(50), Some(configured));
}
// A fresh session has no write timeout.
#[test]
fn test_get_write_timeout_none() {
    let session = HttpSession::new(Box::new(Builder::new().build()));
    let timeout = session.write_timeout(50);
    assert!(timeout.is_none());
}
// A min send rate of zero, or no rate at all, does not produce a timeout.
#[test]
fn test_get_write_timeout_min_send_rate_zero() {
    for &rate in [Some(0), None].iter() {
        let mut session = HttpSession::new(Box::new(Builder::new().build()));
        session.set_min_send_rate(rate);
        assert!(session.write_timeout(50).is_none());
    }
}
// 149000 bytes at 5000 bytes/s gives 29.8 s. That value takes precedence over
// the configured 60 s write timeout.
#[test]
fn test_get_write_timeout_min_send_rate_overrides_write_timeout() {
    let mut session = HttpSession::new(Box::new(Builder::new().build()));
    session.set_write_timeout(Some(Duration::from_secs(60)));
    session.set_min_send_rate(Some(5000));
    let derived = session.write_timeout(149000);
    assert_eq!(derived, Some(Duration::from_millis(29800)));
}
// With a zero-byte buffer and a min send rate configured, the timeout comes
// out as 1 s (presumably a lower bound — confirm against `write_timeout`).
#[test]
fn test_get_write_timeout_min_send_rate_max_zero_buf() {
    let mut session = HttpSession::new(Box::new(Builder::new().build()));
    session.set_min_send_rate(Some(1));
    let derived = session.write_timeout(0);
    assert_eq!(derived, Some(Duration::from_secs(1)));
}
// A request carrying both `Transfer-Encoding: chunked` and Content-Length is
// still parsed. Content-Length is stripped, Transfer-Encoding kept, and
// keepalive is disabled.
#[tokio::test]
async fn test_te_and_cl_disables_keepalive() {
    let raw = b"POST / HTTP/1.1\r\n\
        Host: pingora.org\r\n\
        Transfer-Encoding: chunked\r\n\
        Content-Length: 10\r\n\
        \r\n\
        5\r\n\
        hello\r\n\
        0\r\n\
        \r\n";
    let io = Builder::new().read(&raw[..]).build();
    let mut session = HttpSession::new(Box::new(io));
    session.read_request().await.unwrap();
    assert_eq!(session.keepalive_timeout, KeepaliveStatus::Off);

    let headers = &session.req_header().headers;
    assert!(!headers.contains_key(CONTENT_LENGTH));
    assert!(headers.contains_key(TRANSFER_ENCODING));
}
// Transfer-Encoding is not valid on an HTTP/1.0 request. Parsing fails with
// InvalidHTTPHeader, and the error message names the offending header.
#[tokio::test]
async fn test_http10_request_with_transfer_encoding_rejected() {
    let raw = b"POST / HTTP/1.0\r\n\
        Host: pingora.org\r\n\
        Transfer-Encoding: chunked\r\n\
        \r\n\
        5\r\n\
        hello\r\n\
        0\r\n\
        \r\n";
    let io = Builder::new().read(&raw[..]).build();
    let mut session = HttpSession::new(Box::new(io));
    let outcome = session.read_request().await;
    assert!(outcome.is_err());

    let err = outcome.unwrap_err();
    assert_eq!(err.etype(), &InvalidHTTPHeader);
    assert!(err.to_string().contains("Transfer-Encoding"));
}
// An HTTP/1.0 request framed by Content-Length is accepted, and its version
// is preserved.
#[tokio::test]
async fn test_http10_request_without_transfer_encoding_accepted() {
    let raw = b"POST / HTTP/1.0\r\n\
        Host: pingora.org\r\n\
        Content-Length: 5\r\n\
        \r\n\
        hello";
    let io = Builder::new().read(&raw[..]).build();
    let mut session = HttpSession::new(Box::new(io));
    assert!(session.read_request().await.is_ok());
    assert_eq!(session.req_header().version, http::Version::HTTP_11.min(http::Version::HTTP_10));
}
#[tokio::test]
async fn test_http11_request_with_transfer_encoding_accepted() {
    // Chunked Transfer-Encoding is acceptable on an HTTP/1.1 request.
    let input = b"POST / HTTP/1.1\r\n\
                  Host: pingora.org\r\n\
                  Transfer-Encoding: chunked\r\n\
                  \r\n\
                  5\r\n\
                  hello\r\n\
                  0\r\n\
                  \r\n";
    let mut session = HttpSession::new(Box::new(Builder::new().read(&input[..]).build()));
    assert!(session.read_request().await.is_ok());
    assert_eq!(session.req_header().version, http::Version::HTTP_11);
}
#[tokio::test]
async fn test_request_multiple_transfer_encoding_headers() {
    init_log();
    // Two Transfer-Encoding lines whose final coding is `chunked`: the request
    // is treated as chunked and its body decodes to the single chunk.
    let input = b"POST / HTTP/1.1\r\n\
                  Host: pingora.org\r\n\
                  Transfer-Encoding: gzip\r\n\
                  Transfer-Encoding: chunked\r\n\
                  \r\n\
                  5\r\n\
                  hello\r\n\
                  0\r\n\
                  \r\n";
    let mut session = HttpSession::new(Box::new(Builder::new().read(&input[..]).build()));
    session.read_request().await.unwrap();
    assert!(session.is_chunked_encoding());

    let chunk = session.read_body_bytes().await.unwrap().unwrap();
    assert_eq!(chunk.as_ref(), b"hello");
}
#[tokio::test]
async fn test_request_multiple_te_headers_chunked_not_last() {
    init_log();
    // `chunked` followed by another coding (plus a Content-Length) is not a
    // valid framing, so reading the request must fail.
    let input = b"POST / HTTP/1.1\r\n\
                  Host: pingora.org\r\n\
                  Transfer-Encoding: chunked\r\n\
                  Transfer-Encoding: identity\r\n\
                  Content-Length: 5\r\n\
                  \r\n";
    let mut session = HttpSession::new(Box::new(Builder::new().read(&input[..]).build()));
    assert!(session.read_request().await.is_err());
}
#[tokio::test]
async fn test_no_more_reuses_explicitly_disables_reuse() {
    init_log();
    // The response carries an explicit `Content-Length: 0`, so it is NOT
    // close-delimited. A close-delimited response is refused for reuse on its
    // own (see `test_close_delimited_response_explicitly_disables_reuse`). If
    // this response were close-delimited, this test could not tell whether the
    // exhausted reuse budget had any effect at all.
    let wire_req = b"GET /test HTTP/1.1\r\n\r\n";
    let wire_header = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
    let mock_io = Builder::new()
        .read(&wire_req[..])
        .write(wire_header)
        .build();
    let mut http_session = HttpSession::new(Box::new(mock_io));
    // Exhaust the reuse budget before the request is even read.
    http_session.set_keepalive_reuses_remaining(Some(0));
    http_session.read_request().await.unwrap();

    let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap();
    new_response.append_header("Content-Length", "0").unwrap();
    // Disable response-header updates so the bytes written match `wire_header`.
    http_session.update_resp_headers = false;
    http_session
        .write_response_header(Box::new(new_response))
        .await
        .unwrap();
    // Guard the precondition that makes this test meaningful.
    assert!(!matches!(
        http_session.body_writer.body_mode,
        BodyMode::UntilClose(_)
    ));
    http_session.finish_body().await.unwrap();

    // Even an explicit attempt to re-enable keepalive must not override the
    // exhausted reuse budget.
    http_session.set_keepalive(Some(100));
    let reused = http_session.reuse().await.unwrap();
    assert!(reused.is_none());
}
#[tokio::test]
async fn test_close_delimited_response_explicitly_disables_reuse() {
    init_log();
    // A response with neither Content-Length nor Transfer-Encoding is written
    // in close-delimited mode, and that alone must prevent reuse.
    let wire_req = b"GET /test HTTP/1.1\r\n\r\n";
    let wire_header = b"HTTP/1.1 200 OK\r\n\r\n";
    let mock_io = Builder::new().read(&wire_req[..]).write(wire_header).build();
    let mut session = HttpSession::new(Box::new(mock_io));
    session.read_request().await.unwrap();

    // Disable response-header updates so the bytes written match `wire_header`.
    session.update_resp_headers = false;
    let response = Box::new(ResponseHeader::build(StatusCode::OK, None).unwrap());
    session.write_response_header(response).await.unwrap();
    assert_eq!(session.body_writer.body_mode, BodyMode::UntilClose(0));

    session.finish_body().await.unwrap().unwrap();
    assert!(session.reuse().await.unwrap().is_none());
}
}
#[cfg(test)]
mod test_sync {
    use super::*;
    use http::StatusCode;
    use log::{debug, error};
    use std::str;

    /// Initialises `env_logger` for tests; a failed re-initialisation is ignored.
    fn init_log() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Serialises a response header to wire bytes and re-parses it with
    /// `httparse`. The custom header must survive with its original case.
    #[test]
    fn test_response_to_wire() {
        init_log();
        let mut response = ResponseHeader::build(StatusCode::OK, None).unwrap();
        response.append_header("Foo", "Bar").unwrap();

        let mut wire = BytesMut::with_capacity(INIT_HEADER_BUF_SIZE);
        http_resp_header_to_buf(&response, &mut wire).unwrap();
        debug!("{}", str::from_utf8(wire.as_ref()).unwrap());

        let mut parsed_headers = [httparse::EMPTY_HEADER; 128];
        let mut parsed = httparse::Response::new(&mut parsed_headers);
        let status = parsed.parse(wire.as_ref());
        if let Err(e) = &status {
            error!("{:?}", e);
        }
        assert!(status.unwrap().is_complete());

        let first = &parsed_headers[0];
        assert_eq!(b"Foo", first.name.as_bytes());
        assert_eq!(b"Bar", first.value);
    }
}
#[cfg(test)]
mod test_timeouts {
    use super::*;
    use std::future::IntoFuture;
    use tokio_test::io::{Builder, Mock};

    /// How long the harness waits on a read before declaring it stuck.
    const TEST_MAX_WAIT_FOR_READ: Duration = Duration::from_secs(3);
    /// How long the mocked streams stall; far longer than any test runs.
    const TEST_FOREVER_DURATION: Duration = Duration::from_secs(600);
    /// Session read timeout under test; shorter than `TEST_MAX_WAIT_FOR_READ`.
    const TEST_READ_TIMEOUT: Duration = Duration::from_secs(1);

    /// Returned when the harness deadline fires before the read completes.
    #[derive(Debug)]
    struct ReadBlockedForeverError;

    /// A stream that stalls without ever delivering request bytes.
    fn mocked_blocking_headers_forever_stream() -> Box<Mock> {
        let mock = Builder::new().wait(TEST_FOREVER_DURATION).build();
        Box::new(mock)
    }

    /// A stream that delivers headers announcing a 3-byte body, then stalls
    /// before sending any body bytes.
    fn mocked_blocking_body_forever_stream() -> Box<Mock> {
        let request_line = b"GET / HTTP/1.1\r\n";
        let remaining_headers = b"Host: pingora.example\r\nContent-Length: 3\r\n\r\n";
        let mock = Builder::new()
            .read(&request_line[..])
            .read(&remaining_headers[..])
            .wait(TEST_FOREVER_DURATION)
            .build();
        Box::new(mock)
    }

    /// Races `read_future` against `TEST_MAX_WAIT_FOR_READ`.
    ///
    /// Returns `Err(ReadBlockedForeverError)` if the harness deadline elapses
    /// first. Otherwise returns `Ok` wrapping whatever the read produced.
    async fn test_read_with_tokio_timeout<F, T>(
        read_future: F,
    ) -> Result<Result<T, Box<Error>>, ReadBlockedForeverError>
    where
        F: IntoFuture<Output = Result<T, Box<Error>>>,
    {
        match tokio::time::timeout(TEST_MAX_WAIT_FOR_READ, read_future).await {
            Ok(read_outcome) => Ok(read_outcome),
            Err(_elapsed) => Err(ReadBlockedForeverError),
        }
    }

    #[tokio::test]
    async fn test_read_http_request_headers_timeout_for_read_request() {
        // No session read timeout: only the harness deadline ends the wait.
        let mut session = HttpSession::new(mocked_blocking_headers_forever_stream());
        session.read_timeout = None;
        assert!(test_read_with_tokio_timeout(session.read_request())
            .await
            .is_err());

        // With a read timeout, the session gives up first and reports ReadTimedout.
        let mut session = HttpSession::new(mocked_blocking_headers_forever_stream());
        session.read_timeout = Some(TEST_READ_TIMEOUT);
        let outcome = test_read_with_tokio_timeout(session.read_request()).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().unwrap_err().etype(), &ReadTimedout);
    }

    #[tokio::test]
    async fn test_read_http_body_timeout_for_read_body_bytes() {
        // No session read timeout: only the harness deadline ends the wait.
        let mut session = HttpSession::new(mocked_blocking_body_forever_stream());
        session.read_timeout = None;
        session.read_request().await.unwrap();
        assert!(test_read_with_tokio_timeout(session.read_body_bytes())
            .await
            .is_err());

        // With a read timeout, the session gives up first and reports ReadTimedout.
        let mut session = HttpSession::new(mocked_blocking_body_forever_stream());
        session.read_timeout = Some(TEST_READ_TIMEOUT);
        session.read_request().await.unwrap();
        let outcome = test_read_with_tokio_timeout(session.read_body_bytes()).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().unwrap_err().etype(), &ReadTimedout);
    }
}
#[cfg(test)]
mod test_overread {
    use super::*;
    use rstest::rstest;
    use tokio_test::io::Builder;

    /// Initialises `env_logger` for tests; a failed re-initialisation is ignored.
    fn init_log() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// A Content-Length framed request followed by bytes beyond its declared
    /// body must not be reused. A cleanly framed request is reused. The
    /// outcome is the same whether or not the body was read first.
    #[rstest]
    #[case(0, None, true, true)]
    #[case(0, None, false, true)]
    #[case(0, Some(&b"extra_data_here"[..]), true, false)]
    #[case(0, Some(&b"extra_data_here"[..]), false, false)]
    #[case(5, None, true, true)]
    #[case(5, None, false, true)]
    #[case(5, Some(&b"extra"[..]), true, false)]
    #[case(5, Some(&b"extra"[..]), false, false)]
    #[tokio::test]
    async fn test_reuse_with_preread_body_overread(
        #[case] content_length: usize,
        #[case] extra_bytes: Option<&[u8]>,
        #[case] read_body: bool,
        #[case] expect_reuse: bool,
    ) {
        init_log();
        let body = b"hello";
        let header_block =
            format!("Host: pingora.org\r\nContent-Length: {content_length}\r\n\r\n");
        // Request line, headers, the declared body slice, then any overread bytes.
        let request_data: Vec<u8> = [
            &b"GET / HTTP/1.1\r\n"[..],
            header_block.as_bytes(),
            &body[..content_length],
            extra_bytes.unwrap_or_default(),
        ]
        .concat();

        let mut session = HttpSession::new(Box::new(Builder::new().read(&request_data).build()));
        session.read_request().await.unwrap();

        if read_body {
            let chunk = session.read_body_bytes().await.unwrap();
            if content_length == 0 {
                assert!(chunk.is_none(), "Body should be empty for Content-Length: 0");
            } else {
                let data = chunk.unwrap();
                assert_eq!(data.as_ref(), &body[..content_length]);
            }
            assert_eq!(session.body_bytes_read(), content_length);
        }

        let reused = session.reuse().await.unwrap();
        assert_eq!(reused.is_some(), expect_reuse);
    }

    /// A chunked request with bytes after the terminating chunk is never
    /// reused, whether or not the body was read first.
    #[rstest]
    #[case(true)]
    #[case(false)]
    #[tokio::test]
    async fn test_reuse_with_chunked_body_overread(#[case] read_body: bool) {
        init_log();
        let headers = b"GET / HTTP/1.1\r\nHost: pingora.org\r\nTransfer-Encoding: chunked\r\n\r\n";
        let body_and_extra = b"5\r\nhello\r\n0\r\n\r\nextra";
        let mock_io = Builder::new().read(headers).read(body_and_extra).build();
        let mut session = HttpSession::new(Box::new(mock_io));
        session.read_request().await.unwrap();
        assert!(session.is_chunked_encoding());

        if read_body {
            let first = session.read_body_bytes().await.unwrap();
            assert_eq!(first.unwrap().as_ref(), b"hello");
            let terminator = session.read_body_bytes().await.unwrap();
            assert!(terminator.is_none());
            assert_eq!(session.body_bytes_read(), 5);
        }

        assert!(session.reuse().await.unwrap().is_none());
    }
}