use bytes::Bytes;
use http::HeaderValue;
use http::{header, header::AsHeaderName, HeaderMap, Method};
use log::{debug, trace, warn};
use pingora_error::{Error, ErrorType::*, OkOrErr, Result};
use pingora_http::{RequestHeader, ResponseHeader};
use pingora_timeout::timeout;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use super::body::{BodyReader, BodyWriter};
use crate::protocols::http::{
body_buffer::FixedBuffer,
server::Session as GenericHttpSession,
subrequest::dummy::DummyIO,
v1::common::{header_value_content_length, is_chunked_encoding_from_headers, BODY_BUF_LIMIT},
v1::server::HttpSession as SessionV1,
HttpTask,
};
use crate::protocols::{Digest, SocketAddr};
/// Server-side session of a subrequest.
///
/// Instead of a socket, it exchanges [`HttpTask`]s over in-process channels with
/// the [`SubrequestHandle`] returned by [`HttpSession::new_from_session`].
/// Request-header parsing and access are delegated to an HTTP/1 session
/// (`v1_inner`) that replays the parent request's raw header bytes.
pub struct HttpSession {
// Sends response tasks (header, body, trailers) to the handle; `None` after `shutdown()`.
tx: Option<mpsc::Sender<HttpTask>>,
// Receives request-side tasks (e.g. body chunks) from the handle; `None` after `shutdown()`.
rx: Option<mpsc::Receiver<HttpTask>>,
// HTTP/1 session over a `DummyIO` replaying the parent request header.
v1_inner: Box<SessionV1>,
// `proxy_error`: one-shot notifier used by `on_proxy_failure`, taken on first use.
// `read_req_header`: set once `read_request()` has parsed the request header.
proxy_error: Option<oneshot::Sender<Box<Error>>>, read_req_header: bool,
// The last response header that was successfully sent, if any.
response_written: Option<ResponseHeader>,
// Per-operation timeouts for body reads and writes; `None` means no timeout.
read_timeout: Option<Duration>,
write_timeout: Option<Duration>,
// Overall timeout for `drain_request_body`; `None` means no timeout.
total_drain_timeout: Option<Duration>,
// Running totals of response body bytes written and request body bytes read.
body_bytes_sent: usize,
body_bytes_read: usize,
// Optional copy of the request body read so far (see `enable_retry_buffering`).
retry_buffer: Option<FixedBuffer>,
// Request body framing state; reads from `rx`.
body_reader: BodyReader,
// Response body framing state; writes to `tx`.
body_writer: BodyWriter,
// Set when `write_response_header` saw a successful upgrade handshake.
upgraded: bool,
// When set, body-related request headers are stripped once the header is read.
clear_request_body_headers: bool,
// Connection digest cloned from the parent session.
digest: Option<Box<Digest>>,
}
/// The parent's end of the channels created by [`HttpSession::new_from_session`].
pub struct SubrequestHandle {
/// Sends tasks (e.g. request body chunks) to the subrequest session.
pub tx: mpsc::Sender<HttpTask>,
/// Receives the tasks the subrequest session writes (response header, body, trailers).
pub rx: mpsc::Receiver<HttpTask>,
/// Paired with the sender handed to the subrequest's `BodyReader`.
/// NOTE(review): presumably signalled when the subrequest first wants the request body — confirm in `body.rs`.
pub subreq_wants_body: oneshot::Receiver<()>,
/// Receives the error passed to `HttpSession::on_proxy_failure`, if it is called.
pub subreq_proxy_error: oneshot::Receiver<Box<Error>>,
}
impl SubrequestHandle {
    /// Spawns a background task that receives and discards everything the
    /// subrequest session sends until its senders are gone. It then logs
    /// "subrequest dropped" at trace level.
    ///
    /// The handle's own sender is moved into the task and held until draining
    /// finishes. As a result, this call does not close the subrequest's
    /// receiving side.
    pub fn drain_tasks(mut self) -> tokio::task::JoinHandle<()> {
        let drain = async move {
            let _keep_sender_open = self.tx;
            loop {
                match self.rx.recv().await {
                    Some(_) => continue,
                    None => break,
                }
            }
            trace!("subrequest dropped");
        };
        tokio::spawn(drain)
    }
}
impl HttpSession {
/// Creates a subrequest session from `session`, plus the handle the parent
/// uses to talk to it.
///
/// The parent's request header is serialized with `to_h1_raw()` and replayed
/// through a [`DummyIO`] into a fresh HTTP/1 session. Because of this,
/// `read_request()` must be called on the returned session before its request
/// header is available. The parent's connection digest, if any, is cloned.
///
/// The pair is connected by two bounded task channels (capacity 4), one per
/// direction. Two one-shot channels carry the "wants body" signal and proxy
/// errors.
pub fn new_from_session(session: &GenericHttpSession) -> (Self, SubrequestHandle) {
    const CHANNEL_BUFFER_SIZE: usize = 4;

    let raw_request_header = session.to_h1_raw();
    let inner = SessionV1::new(Box::new(DummyIO::new(&raw_request_header)));

    // handle -> subrequest (e.g. request body)
    let (to_subrequest, subrequest_inbox) = mpsc::channel(CHANNEL_BUFFER_SIZE);
    // subrequest -> handle (response header / body / trailers)
    let (subrequest_outbox, from_subrequest) = mpsc::channel(CHANNEL_BUFFER_SIZE);
    let (wants_body_notify, wants_body_signal) = oneshot::channel();
    let (proxy_error_notify, proxy_error_signal) = oneshot::channel();

    let handle = SubrequestHandle {
        tx: to_subrequest,
        rx: from_subrequest,
        subreq_wants_body: wants_body_signal,
        subreq_proxy_error: proxy_error_signal,
    };

    let subrequest = HttpSession {
        tx: Some(subrequest_outbox),
        rx: Some(subrequest_inbox),
        v1_inner: Box::new(inner),
        proxy_error: Some(proxy_error_notify),
        read_req_header: false,
        response_written: None,
        read_timeout: None,
        write_timeout: None,
        total_drain_timeout: None,
        body_bytes_sent: 0,
        body_bytes_read: 0,
        retry_buffer: None,
        body_reader: BodyReader::new(Some(wants_body_notify)),
        body_writer: BodyWriter::new(),
        upgraded: false,
        clear_request_body_headers: false,
        digest: session.digest().cloned().map(Box::new),
    };

    (subrequest, handle)
}
/// Parses the request header replayed from the parent session.
///
/// On success the header is marked as read. If `clear_request_body_headers()`
/// was requested earlier, the body-related headers are stripped now.
///
/// # Errors
/// `InternalError` when the inner session produced no request header.
/// Errors from the inner session's read are propagated unchanged.
pub async fn read_request(&mut self) -> Result<Option<usize>> {
    let Some(n) = self.v1_inner.read_request().await? else {
        return Error::e_explain(InternalError, "no session request header provided");
    };
    self.read_req_header = true;
    if self.clear_request_body_headers {
        self.clear_request_body_headers();
    }
    Ok(Some(n))
}
/// Validates the parsed request header by delegating to the inner HTTP/1 session.
pub fn validate_request(&self) -> Result<()> {
self.v1_inner.validate_request()
}
/// The parsed request header (held by the inner HTTP/1 session).
pub fn req_header(&self) -> &RequestHeader {
self.v1_inner.req_header()
}
/// Mutable access to the parsed request header.
pub fn req_header_mut(&mut self) -> &mut RequestHeader {
self.v1_inner.req_header_mut()
}
/// Looks up a request header value by name; delegates to the inner HTTP/1 session.
pub fn get_header(&self, name: impl AsHeaderName) -> Option<&HeaderValue> {
self.v1_inner.get_header(name)
}
// Request method, if known; delegates to the inner HTTP/1 session.
pub(super) fn get_method(&self) -> Option<&http::Method> {
self.v1_inner.get_method()
}
// Request path bytes; delegates to the inner HTTP/1 session.
pub(super) fn get_path(&self) -> &[u8] {
self.v1_inner.get_path()
}
// Request host bytes; delegates to the inner HTTP/1 session.
pub(super) fn get_host(&self) -> &[u8] {
self.v1_inner.get_host()
}
/// One-line description of the request for logging:
/// `METHOD PATH, Host: HOST (subrequest)`.
///
/// Uses `-` when the method is unknown. Path and host bytes are decoded
/// lossily as UTF-8.
pub fn request_summary(&self) -> String {
    let method = match self.get_method() {
        Some(m) => m.as_str(),
        None => "-",
    };
    let path = String::from_utf8_lossy(self.get_path());
    let host = String::from_utf8_lossy(self.get_host());
    format!("{} {}, Host: {} (subrequest)", method, path, host)
}
/// Whether the request asks for a protocol upgrade; delegates to the inner HTTP/1 session.
pub fn is_upgrade_req(&self) -> bool {
self.v1_inner.is_upgrade_req()
}
/// Raw bytes of the named request header's value; delegates to the inner HTTP/1 session.
/// NOTE(review): presumably an empty slice when the header is absent — confirm in v1::server.
pub fn get_header_bytes(&self, name: impl AsHeaderName) -> &[u8] {
self.v1_inner.get_header_bytes(name)
}
/// Reads the next request body chunk.
///
/// The chunk's length is added to `body_bytes_read`. When retry buffering is
/// enabled, the chunk is also appended to the retry buffer. Returns `Ok(None)`
/// once the body has ended.
pub async fn read_body_bytes(&mut self) -> Result<Option<Bytes>> {
    let chunk = self.read_body().await?;
    if let Some(bytes) = chunk.as_ref() {
        self.body_bytes_read += bytes.len();
        if let Some(retry) = &mut self.retry_buffer {
            retry.write_to_buffer(bytes);
        }
    }
    Ok(chunk)
}
async fn do_read_body(&mut self) -> Result<Option<Bytes>> {
self.init_body_reader();
self.body_reader
.read_body(self.rx.as_mut().expect("rx valid before shutdown"))
.await
}
async fn read_body(&mut self) -> Result<Option<Bytes>> {
match self.read_timeout {
Some(t) => match timeout(t, self.do_read_body()).await {
Ok(res) => res,
Err(_) => Error::e_explain(
ReadTimedout,
format!("reading body, timeout: {t:?} (subrequest)"),
),
},
None => self.do_read_body().await,
}
}
/// Reads and discards request body chunks until the body ends.
///
/// Each chunk goes through `read_body_bytes`. Byte counting, retry buffering
/// and the per-read `read_timeout` therefore all still apply.
async fn do_drain_request_body(&mut self) -> Result<()> {
    while self.read_body_bytes().await?.is_some() {}
    Ok(())
}

/// Consumes the rest of the request body, if any.
///
/// When `total_drain_timeout` is set, the whole drain is bounded by it.
/// Returns immediately if the body is already done.
///
/// # Errors
/// `ReadTimedout` when the overall drain timeout elapses. Otherwise, the first
/// error returned while reading.
pub async fn drain_request_body(&mut self) -> Result<()> {
    if self.is_body_done() {
        return Ok(());
    }
    if let Some(t) = self.total_drain_timeout {
        return match timeout(t, self.do_drain_request_body()).await {
            Ok(drained) => drained,
            Err(_) => Error::e_explain(
                ReadTimedout,
                format!("draining body, timeout: {t:?} (subrequest)"),
            ),
        };
    }
    self.do_drain_request_body().await
}
/// Whether the request body has been fully read.
/// Initializes the body reader first if it has not been set up yet.
pub fn is_body_done(&mut self) -> bool {
self.init_body_reader();
self.body_reader.body_done()
}
/// Whether the request body is empty, as reported by the body reader.
/// Initializes the body reader first if it has not been set up yet.
pub fn is_body_empty(&mut self) -> bool {
self.init_body_reader();
self.body_reader.body_empty()
}
/// Sends a response header to the parent as an `HttpTask::Header` that is
/// never marked end-of-stream.
///
/// Once a non-1xx header has been sent, or the session has been upgraded,
/// further headers are ignored: a warning is logged and `Ok(())` is returned.
///
/// For a 101 or final header, the upgrade check runs first. On success the
/// session is marked upgraded and the request body reader is switched to
/// close-delimited. Response body framing is then set up.
///
/// The header is recorded in `response_written` only after it has been sent.
///
/// # Errors
/// `WriteError` if sending on the task channel fails.
///
/// # Panics
/// If called after `shutdown()`.
pub async fn write_response_header(&mut self, header: Box<ResponseHeader>) -> Result<()> {
    let header_locked = self
        .response_written
        .as_ref()
        .is_some_and(|prev| !prev.status.is_informational() || self.upgraded);
    if header_locked {
        warn!("Respond header is already sent, cannot send again (subrequest)");
        return Ok(());
    }

    let status = header.status;
    if status == 101 || !status.is_informational() {
        match self.is_upgrade(&header) {
            Some(true) => {
                debug!("ok upgrade handshake");
                self.upgraded = true;
                if self.body_reader.need_init() {
                    self.init_body_reader();
                } else {
                    self.body_reader.convert_to_close_delimited();
                }
            }
            Some(false) => {
                debug!("bad upgrade handshake!");
            }
            None => {}
        }
        self.init_body_writer(&header);
    }

    debug!("send response header (subrequest)");
    let outbox = self.tx.as_mut().expect("tx valid before shutdown");
    match outbox.send(HttpTask::Header(header.clone(), false)).await {
        Ok(()) => {
            self.response_written = Some(*header);
            Ok(())
        }
        Err(e) => Error::e_because(WriteError, "writing response header", e),
    }
}
/// The last response header that was successfully sent, if any.
pub fn response_written(&self) -> Option<&ResponseHeader> {
self.response_written.as_ref()
}
/// Checks `header` as an upgrade handshake response for this request; delegates to the inner HTTP/1 session.
/// NOTE(review): `None` presumably means the exchange is not an upgrade attempt — confirm in v1::server.
pub fn is_upgrade(&self, header: &ResponseHeader) -> Option<bool> {
self.v1_inner.is_upgrade(header)
}
/// Whether `write_response_header` recorded a successful upgrade on this session.
pub fn was_upgraded(&self) -> bool {
self.upgraded
}
/// Chooses the response body framing for `header`:
///
/// - 204, 304, or any response to a HEAD request: zero-length body.
/// - Any other 1xx except 101: left uninitialized, since it is not a final
///   response.
/// - Successful upgrade, or chunked `Transfer-Encoding`: close-delimited.
/// - Otherwise: the `Content-Length` value, or close-delimited when
///   `header_value_content_length` yields none.
fn init_body_writer(&mut self, header: &ResponseHeader) {
    use http::StatusCode;

    let status = header.status;
    let no_body_status = matches!(status, StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED);
    if no_body_status || self.get_method() == Some(&Method::HEAD) {
        self.body_writer.init_content_length(0);
        return;
    }
    if status.is_informational() && status != StatusCode::SWITCHING_PROTOCOLS {
        return;
    }

    let close_delimited = self.is_upgrade(header) == Some(true)
        || is_chunked_encoding_from_headers(&header.headers);
    let fixed_length = if close_delimited {
        None
    } else {
        header_value_content_length(header.headers.get(http::header::CONTENT_LENGTH))
    };
    match fixed_length {
        Some(length) => {
            self.body_writer.init_content_length(length);
        }
        None => {
            self.body_writer.init_close_delimited();
        }
    }
}
/// Same as `write_response_header`, but takes a borrowed header and clones it.
pub async fn write_response_header_ref(&mut self, resp: &ResponseHeader) -> Result<()> {
self.write_response_header(Box::new(resp.clone())).await
}
async fn do_write_body(&mut self, buf: Bytes) -> Result<Option<usize>> {
let written = self
.body_writer
.write_body(self.tx.as_mut().expect("tx valid before shutdown"), buf)
.await;
if let Ok(Some(num_bytes)) = written {
self.body_bytes_sent += num_bytes;
}
written
}
pub async fn write_body(&mut self, buf: Bytes) -> Result<Option<usize>> {
match self.write_timeout {
Some(t) => match timeout(t, self.do_write_body(buf)).await {
Ok(res) => res,
Err(_) => Error::e_explain(WriteTimedout, format!("writing body, timeout: {t:?}")),
},
None => self.do_write_body(buf).await,
}
}
/// On an upgraded session, re-initializes a request body reader that has not
/// finished with a zero content length.
/// NOTE(review): presumably this makes the reader report done; confirm in `body.rs`.
fn maybe_force_close_body_reader(&mut self) {
    if !self.upgraded {
        return;
    }
    if !self.body_reader.body_done() {
        self.body_reader.init_content_length(0);
    }
}

/// Finishes the response body, then force-closes the request body reader for
/// upgraded sessions.
///
/// Returns the body writer's `finish` result. The tests in this file observe it
/// as the total response body byte count.
///
/// # Panics
/// If called after `shutdown()`.
pub async fn finish(&mut self) -> Result<Option<usize>> {
    let outbox = self.tx.as_mut().expect("tx valid before shutdown");
    let finished = self.body_writer.finish(outbox).await?;
    self.maybe_force_close_body_reader();
    Ok(finished)
}
/// Reports a proxy failure to the handle's `subreq_proxy_error` receiver.
///
/// Only the first call delivers anything, because the notifier is consumed. A
/// failed send (receiver already dropped) is deliberately ignored.
pub fn on_proxy_failure(&mut self, e: Box<Error>) {
    let Some(notify) = self.proxy_error.take() else {
        return;
    };
    // Best-effort: the parent may no longer be listening.
    let _ = notify.send(e);
}
/// Total response body bytes written so far, as reported by the body writer.
pub fn body_bytes_sent(&self) -> usize {
self.body_bytes_sent
}
/// Total request body bytes read so far through `read_body_bytes`.
pub fn body_bytes_read(&self) -> usize {
self.body_bytes_read
}
// Whether the request header declares a chunked Transfer-Encoding.
fn is_chunked_encoding(&self) -> bool {
is_chunked_encoding_from_headers(&self.req_header().headers)
}
/// Requests removal of the body-related request headers: Content-Length,
/// Transfer-Encoding, Content-Type and Content-Encoding.
///
/// If the request header has already been read, the headers are removed now.
/// The request is also remembered, so `read_request()` applies it after parsing.
pub fn clear_request_body_headers(&mut self) {
    self.clear_request_body_headers = true;
    if !self.read_req_header {
        return;
    }
    let req = self.v1_inner.req_header_mut();
    for name in [
        &header::CONTENT_LENGTH,
        &header::TRANSFER_ENCODING,
        &header::CONTENT_TYPE,
        &header::CONTENT_ENCODING,
    ] {
        req.remove_header(name);
    }
}
/// Sets up request body framing the first time it is needed:
///
/// - After an upgrade, or for a chunked request: close-delimited.
/// - Otherwise: the request's Content-Length. A `None` from
///   `header_value_content_length` (e.g. no header) is treated as 0.
///
/// The retry buffer, if enabled, is cleared when this initialization happens.
fn init_body_reader(&mut self) {
    if !self.body_reader.need_init() {
        return;
    }
    if let Some(retry) = self.retry_buffer.as_mut() {
        retry.clear();
    }
    if self.was_upgraded() || self.is_chunked_encoding() {
        self.body_reader.init_close_delimited();
    } else {
        let length =
            header_value_content_length(self.get_header(header::CONTENT_LENGTH)).unwrap_or(0);
        self.body_reader.init_content_length(length);
    }
}
/// Whether retry buffering is enabled and the buffer has been truncated.
pub fn retry_buffer_truncated(&self) -> bool {
    self.retry_buffer
        .as_ref()
        .is_some_and(|buf| buf.is_truncated())
}

/// Enables buffering of the request body, capped at `BODY_BUF_LIMIT` bytes.
/// Calling it again keeps the existing buffer.
pub fn enable_retry_buffering(&mut self) {
    self.retry_buffer
        .get_or_insert_with(|| FixedBuffer::new(BODY_BUF_LIMIT));
}

/// The buffered request body, if retry buffering is enabled and the buffer
/// has not been truncated.
pub fn get_retry_buffer(&self) -> Option<Bytes> {
    let buf = self.retry_buffer.as_ref()?;
    if buf.is_truncated() {
        None
    } else {
        buf.get_buffer()
    }
}
/// Waits for the next meaningful task from the handle.
///
/// `Done` tasks and body tasks with no data or empty data are skipped.
///
/// # Errors
/// `ReadError` if the channel closes while waiting.
///
/// # Panics
/// If called after `shutdown()`.
pub async fn idle(&mut self) -> Result<HttpTask> {
    let inbox = self.rx.as_mut().expect("rx valid before shutdown");
    loop {
        let task = inbox
            .recv()
            .await
            .or_err(ReadError, "during HTTP idle state")?;
        let ignorable = match &task {
            HttpTask::Done => true,
            HttpTask::Body(body, _) => body.as_ref().is_none_or(|b| b.is_empty()),
            _ => false,
        };
        if !ignorable {
            return Ok(task);
        }
    }
}
/// Reads the next request body chunk when more body is expected.
///
/// Otherwise, waits for the next meaningful task and reports it as an error.
///
/// # Errors
/// When no body is expected or the body is done, this always errors: either
/// `ConnectError` naming the unexpected task, or the error from `idle()`.
pub async fn read_body_or_idle(&mut self, no_body_expected: bool) -> Result<Option<Bytes>> {
    let expect_more_body = !no_body_expected && !self.is_body_done();
    if expect_more_body {
        return self.read_body_bytes().await;
    }
    let read_task = self.idle().await?;
    Error::e_explain(
        ConnectError,
        format!("Sent unexpected task {read_task:?} after end of body (subrequest)"),
    )
}
/// Raw request header bytes, as held by the inner HTTP/1 session.
pub fn get_headers_raw_bytes(&self) -> Bytes {
    self.v1_inner.get_headers_raw_bytes()
}

/// Drops this session's ends of both task channels.
///
/// Any later operation that needs a channel panics.
pub fn shutdown(&mut self) {
    // Overwriting with `None` drops the sender first, then the receiver.
    self.tx = None;
    self.rx = None;
}
/// Sets the timeout applied to each request body read; `None` disables it.
pub fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.read_timeout = timeout;
}
/// The per-read request body timeout, if any.
pub fn get_read_timeout(&self) -> Option<Duration> {
self.read_timeout
}
/// Sets the timeout applied to each response body write; `None` disables it.
pub fn set_write_timeout(&mut self, timeout: Option<Duration>) {
self.write_timeout = timeout;
}
/// The per-write response body timeout, if any.
pub fn get_write_timeout(&self) -> Option<Duration> {
self.write_timeout
}
/// Sets the overall timeout for `drain_request_body`; `None` disables it.
pub fn set_total_drain_timeout(&mut self, timeout: Option<Duration>) {
self.total_drain_timeout = timeout;
}
/// The overall request body drain timeout, if any.
pub fn get_total_drain_timeout(&self) -> Option<Duration> {
self.total_drain_timeout
}
/// The connection digest cloned from the parent session, if it had one.
pub fn digest(&self) -> Option<&Digest> {
self.digest.as_deref()
}
/// Mutable access to this session's copy of the connection digest.
pub fn digest_mut(&mut self) -> Option<&mut Digest> {
self.digest.as_deref_mut()
}
/// The peer address recorded in the connection digest.
///
/// `None` when there is no digest, no socket digest, or no peer address.
pub fn client_addr(&self) -> Option<&SocketAddr> {
    self.digest()?.socket_digest.as_ref()?.peer_addr()
}

/// The local address recorded in the connection digest.
///
/// `None` when there is no digest, no socket digest, or no local address.
pub fn server_addr(&self) -> Option<&SocketAddr> {
    self.digest()?.socket_digest.as_ref()?.local_addr()
}
/// Sends a `100 Continue` header, unless a response header has already been
/// written. In that case this does nothing.
pub async fn write_continue_response(&mut self) -> Result<()> {
    if self.response_written.is_some() {
        return Ok(());
    }
    let continue_header = ResponseHeader::build(100, Some(0)).unwrap();
    self.write_response_header(Box::new(continue_header)).await
}
/// Writes the payload of a `Body` (`upgraded == false`) or `UpgradedBody`
/// (`upgraded == true`) task. Missing or empty data is skipped.
///
/// # Panics
/// If the task kind does not match whether this session has been upgraded.
///
/// # Errors
/// Body write errors, mapped with `into_down()`.
async fn write_non_empty_body(&mut self, data: Option<Bytes>, upgraded: bool) -> Result<()> {
    match (upgraded, self.upgraded) {
        (true, false) => panic!("Unexpected UpgradedBody task received on un-upgraded downstream session (subrequest)"),
        (false, true) => panic!("Unexpected Body task received on upgraded downstream session (subrequest)"),
        _ => {}
    }
    let Some(payload) = data.filter(|d| !d.is_empty()) else {
        return Ok(());
    };
    self.write_body(payload).await.map_err(|e| e.into_down())?;
    Ok(())
}
/// Writes one downstream task without finishing the response.
///
/// Returns whether the task ends the stream: the task's own flag for headers
/// and bodies, and `true` for trailers and `Done`.
///
/// # Errors
/// A `Failed` task's error is returned unchanged. Header and body write errors
/// are mapped with `into_down()`. Trailer write errors are returned unchanged.
async fn write_single_task(&mut self, task: HttpTask) -> Result<bool> {
    match task {
        HttpTask::Header(header, end_stream) => {
            self.write_response_header(header)
                .await
                .map_err(|e| e.into_down())?;
            Ok(end_stream)
        }
        HttpTask::Body(data, end_stream) => {
            self.write_non_empty_body(data, false).await?;
            Ok(end_stream)
        }
        HttpTask::UpgradedBody(data, end_stream) => {
            self.write_non_empty_body(data, true).await?;
            Ok(end_stream)
        }
        HttpTask::Trailer(trailers) => {
            self.write_trailers(trailers).await?;
            Ok(true)
        }
        HttpTask::Done => Ok(true),
        HttpTask::Failed(e) => Err(e),
    }
}

/// Finishes the response body when `end_stream` is set.
///
/// Returns whether the response is complete: `end_stream` is set, or the body
/// writer has already finished.
async fn finish_if_end_stream(&mut self, end_stream: bool) -> Result<bool> {
    if end_stream {
        self.finish().await.map_err(|e| e.into_down())?;
    }
    Ok(end_stream || self.body_writer.finished())
}

/// Writes a single downstream task and finishes the response if it ends the
/// stream. Returns whether the response is complete.
async fn response_duplex(&mut self, task: HttpTask) -> Result<bool> {
    let end_stream = self.write_single_task(task).await?;
    self.finish_if_end_stream(end_stream).await
}

/// Writes a batch of downstream tasks in order.
///
/// Every task is written, even those after an end-of-stream task. The response
/// is finished once at the end if any task ended the stream. Returns whether
/// the response is complete. The first error aborts the batch.
pub async fn response_duplex_vec(&mut self, tasks: Vec<HttpTask>) -> Result<bool> {
    let mut end_stream = false;
    for task in tasks {
        // Write first, then OR, so a task after an end-of-stream task is still written.
        end_stream = self.write_single_task(task).await? || end_stream;
    }
    self.finish_if_end_stream(end_stream).await
}
/// Forwards optional response trailers through the body writer.
///
/// # Panics
/// If called after `shutdown()`.
pub async fn write_trailers(&mut self, trailers: Option<Box<HeaderMap>>) -> Result<()> {
    let outbox = self.tx.as_mut().expect("tx valid before shutdown");
    self.body_writer.write_trailers(outbox, trailers).await
}
}
// Unit tests. Most tests parse a request with a parent HTTP/1 session over a
// `tokio_test` mock IO, derive a subrequest session from it, and then drive the
// subrequest through the `SubrequestHandle` channels.
#[cfg(test)]
mod tests_stream {
use super::*;
use crate::protocols::http::subrequest::body::{BodyMode, ParseState};
use bytes::BufMut;
use http::StatusCode;
use rstest::rstest;
use std::str;
use tokio_test::io::Builder;
// Initializes env_logger for tests; repeated initialization attempts are ignored.
fn init_log() {
let _ = env_logger::builder().is_test(true).try_init();
}
// Parses `input` with a parent HTTP/1 session, derives a subrequest session from
// it, and has the subrequest read its own copy of the request header.
async fn session_from_input(input: &[u8]) -> (HttpSession, SubrequestHandle) {
let mock_io = Builder::new().read(input).build();
let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io));
http_stream.read_request().await.unwrap();
let (mut http_stream, handle) = HttpSession::new_from_session(&http_stream);
http_stream.read_request().await.unwrap();
(http_stream, handle)
}
// Builds a GET request with the given Upgrade and Connection header values.
async fn build_upgrade_req(upgrade: &str, conn: &str) -> (HttpSession, SubrequestHandle) {
let input = format!("GET / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: {upgrade}\r\nConnection: {conn}\r\n\r\n");
session_from_input(input.as_bytes()).await
}
// Builds a plain GET request with only a Host header.
async fn build_req() -> (HttpSession, SubrequestHandle) {
let input = "GET / HTTP/1.1\r\nHost: pingora.org\r\n\r\n".to_string();
session_from_input(input.as_bytes()).await
}
// A request with no headers still yields its method and path.
#[tokio::test]
async fn read_basic() {
init_log();
let input = b"GET / HTTP/1.1\r\n\r\n";
let (http_stream, _handle) = session_from_input(input).await;
assert_eq!(0, http_stream.req_header().headers.len());
assert_eq!(Method::GET, http_stream.req_header().method);
assert_eq!(b"/", http_stream.req_header().uri.path().as_bytes());
}
// Upgrade detection:
// - an HTTP/1.0 request is not an upgrade;
// - an Upgrade header is required, and Connection alone is not enough;
// - method and header-value casing do not matter.
#[tokio::test]
async fn read_upgrade_req() {
let input = b"GET / HTTP/1.0\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\n\r\n";
let (http_stream, _handle) = session_from_input(input).await;
assert!(!http_stream.is_upgrade_req());
let input = b"POST / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\n\r\n";
let (http_stream, _handle) = session_from_input(input).await;
assert!(http_stream.is_upgrade_req());
let input = b"GET / HTTP/1.1\r\nHost: pingora.org\r\nConnection: upgrade\r\n\r\n";
let (http_stream, _handle) = session_from_input(input).await;
assert!(!http_stream.is_upgrade_req());
let input = b"GET / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: WebSocket\r\n\r\n";
let (http_stream, _handle) = session_from_input(input).await;
assert!(http_stream.is_upgrade_req());
let (http_stream, _handle) = build_upgrade_req("websocket", "Upgrade").await;
assert!(http_stream.is_upgrade_req());
let (http_stream, _handle) = build_upgrade_req("WebSocket", "Upgrade").await;
assert!(http_stream.is_upgrade_req());
}
// A 1xx (100) response to an upgrade request leaves the bodyless GET's request body done.
#[tokio::test]
async fn read_upgrade_req_with_1xx_response() {
let (mut http_stream, _handle) = build_upgrade_req("websocket", "upgrade").await;
assert!(http_stream.is_upgrade_req());
let mut response = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap();
response.set_version(http::Version::HTTP_11);
http_stream
.write_response_header(Box::new(response))
.await
.unwrap();
assert!(http_stream.is_body_done());
}
// A response header reaches the handle as a Header task not marked end-of-stream.
#[tokio::test]
async fn write() {
let (mut http_stream, mut handle) = build_req().await;
let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap();
new_response.append_header("Foo", "Bar").unwrap();
http_stream
.write_response_header_ref(&new_response)
.await
.unwrap();
match handle.rx.try_recv().unwrap() {
HttpTask::Header(header, end) => {
assert_eq!(header.status, StatusCode::OK);
assert_eq!(header.headers["foo"], "Bar");
assert!(!end);
}
t => panic!("unexpected task {t:?}"),
}
}
// A 100 Continue may be followed by a final 200; both reach the handle.
#[tokio::test]
async fn write_informational() {
let (mut http_stream, mut handle) = build_req().await;
let response_100 = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap();
http_stream
.write_response_header_ref(&response_100)
.await
.unwrap();
match handle.rx.try_recv().unwrap() {
HttpTask::Header(header, end) => {
assert_eq!(header.status, StatusCode::CONTINUE);
assert!(!end);
}
t => panic!("unexpected task {t:?}"),
}
let response_200 = ResponseHeader::build(StatusCode::OK, None).unwrap();
http_stream
.write_response_header_ref(&response_200)
.await
.unwrap();
match handle.rx.try_recv().unwrap() {
HttpTask::Header(header, end) => {
assert_eq!(header.status, StatusCode::OK);
assert!(!end);
}
t => panic!("unexpected task {t:?}"),
}
}
// After a 101 upgrade, body writes are forwarded and a later header (502) is
// dropped: only the Body task follows the 101 on the channel.
#[tokio::test]
async fn write_101_switching_protocol() {
let (mut http_stream, mut handle) = build_upgrade_req("WebSocket", "Upgrade").await;
let mut response_101 =
ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap();
response_101.append_header("Foo", "Bar").unwrap();
http_stream
.write_response_header_ref(&response_101)
.await
.unwrap();
match handle.rx.try_recv().unwrap() {
HttpTask::Header(header, end) => {
assert_eq!(header.status, StatusCode::SWITCHING_PROTOCOLS);
assert!(!end);
}
t => panic!("unexpected task {t:?}"),
}
assert!(http_stream.upgraded);
let wire_body = Bytes::from(&b"PAYLOAD"[..]);
let n = http_stream
.write_body(wire_body.clone())
.await
.unwrap()
.unwrap();
assert_eq!(wire_body.len(), n);
let response_502 = ResponseHeader::build(StatusCode::BAD_GATEWAY, None).unwrap();
http_stream
.write_response_header_ref(&response_502)
.await
.unwrap();
match handle.rx.try_recv().unwrap() {
HttpTask::Body(body, _end) => {
assert_eq!(body.unwrap().len(), n);
}
t => panic!("unexpected task {t:?}"),
}
assert_eq!(
handle.rx.try_recv().unwrap_err(),
mpsc::error::TryRecvError::Empty
);
}
// A Content-Length response uses ContentLength framing; finish() reports the bytes written.
#[tokio::test]
async fn write_body_cl() {
let (mut http_stream, _handle) = build_req().await;
let wire_body = Bytes::from(&b"a"[..]);
let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap();
new_response.append_header("Content-Length", "1").unwrap();
http_stream
.write_response_header_ref(&new_response)
.await
.unwrap();
assert_eq!(
http_stream.body_writer.body_mode,
BodyMode::ContentLength(1, 0)
);
let n = http_stream
.write_body(wire_body.clone())
.await
.unwrap()
.unwrap();
assert_eq!(wire_body.len(), n);
let n = http_stream.finish().await.unwrap().unwrap();
assert_eq!(wire_body.len(), n);
}
// Without Content-Length the response body is close-delimited (UntilClose).
#[tokio::test]
async fn write_body_until_close() {
let (mut http_stream, _handle) = build_req().await;
let new_response = ResponseHeader::build(StatusCode::OK, None).unwrap();
http_stream
.write_response_header_ref(&new_response)
.await
.unwrap();
assert_eq!(http_stream.body_writer.body_mode, BodyMode::UntilClose(0));
let wire_body = Bytes::from(&b"PAYLOAD"[..]);
let n = http_stream
.write_body(wire_body.clone())
.await
.unwrap()
.unwrap();
assert_eq!(wire_body.len(), n);
let n = http_stream.finish().await.unwrap().unwrap();
assert_eq!(wire_body.len(), n);
}
// A space in the request target comes back percent-encoded in the path.
// A Content-Length body sent through the handle is read in full.
#[tokio::test]
async fn read_with_illegal() {
init_log();
let input1 = b"GET /a?q=b c HTTP/1.1\r\n";
let input2 = b"Host: pingora.org\r\nContent-Length: 3\r\n\r\n";
let input3 = b"abc";
let mock_io = Builder::new().read(&input1[..]).read(&input2[..]).build();
let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io));
http_stream.read_request().await.unwrap();
let (mut http_stream, handle) = HttpSession::new_from_session(&http_stream);
http_stream.read_request().await.unwrap();
handle
.tx
.send(HttpTask::Body(Some(Bytes::from(&input3[..])), false))
.await
.unwrap();
assert_eq!(http_stream.get_path(), &b"/a?q=b%20c"[..]);
let res = http_stream.read_body().await.unwrap().unwrap();
assert_eq!(res, &input3[..]);
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(3));
}
// The handle never receives, so with a channel capacity of 4 the header plus
// three body chunks fill the channel and the fourth write times out.
#[tokio::test]
async fn test_write_body_write_timeout() {
let (mut http_stream, _handle) = build_req().await;
http_stream.write_timeout = Some(Duration::from_millis(100));
let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap();
new_response.append_header("Content-Length", "10").unwrap();
http_stream
.write_response_header_ref(&new_response)
.await
.unwrap();
let body_write_buf = Bytes::from(&b"abc"[..]);
http_stream
.write_body(body_write_buf.clone())
.await
.unwrap();
http_stream
.write_body(body_write_buf.clone())
.await
.unwrap();
http_stream.write_body(body_write_buf).await.unwrap();
let last_body = Bytes::from(&b"a"[..]);
let res = http_stream.write_body(last_body).await;
assert_eq!(res.unwrap_err().etype(), &WriteTimedout);
}
// write_continue_response sends a 100 Continue header task.
#[tokio::test]
async fn test_write_continue_resp() {
let (mut http_stream, mut handle) = build_req().await;
http_stream.write_continue_response().await.unwrap();
match handle.rx.try_recv().unwrap() {
HttpTask::Header(header, end) => {
assert_eq!(header.status, StatusCode::CONTINUE);
assert!(!end);
}
t => panic!("unexpected task {t:?}"),
}
}
// Like session_from_input, but leaves the subrequest's read_request() to the caller.
async fn session_from_input_no_validate(input: &[u8]) -> (HttpSession, SubrequestHandle) {
let mock_io = Builder::new().read(input).build();
let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io));
http_stream.read_request().await.unwrap();
let (http_stream, handle) = HttpSession::new_from_session(&http_stream);
(http_stream, handle)
}
// The parent HTTP/1 session rejects malformed Content-Length values with InvalidHTTPHeader.
#[rstest]
#[case::negative("-1")]
#[case::not_a_number("abc")]
#[case::float("1.5")]
#[case::empty("")]
#[case::spaces(" ")]
#[case::mixed("123abc")]
#[tokio::test]
async fn validate_request_rejects_invalid_content_length(#[case] invalid_value: &str) {
init_log();
let input = format!(
"POST / HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: {}\r\n\r\n",
invalid_value
);
let mock_io = Builder::new().read(input.as_bytes()).build();
let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io));
let res = http_stream.read_request().await;
assert!(res.is_err());
assert_eq!(
res.unwrap_err().etype(),
&pingora_error::ErrorType::InvalidHTTPHeader
);
}
// The subrequest's read_request accepts well-formed Content-Length values.
#[rstest]
#[case::valid_zero("0")]
#[case::valid_small("123")]
#[case::valid_large("999999")]
#[tokio::test]
async fn validate_request_accepts_valid_content_length(#[case] valid_value: &str) {
init_log();
let input = format!(
"POST / HTTP/1.1\r\nHost: pingora.org\r\nContent-Length: {}\r\n\r\n",
valid_value
);
let (mut http_stream, _handle) = session_from_input_no_validate(input.as_bytes()).await;
let res = http_stream.read_request().await;
assert!(res.is_ok());
}
// The subrequest's read_request accepts a request without Content-Length.
#[tokio::test]
async fn validate_request_accepts_no_content_length() {
init_log();
let input = b"GET / HTTP/1.1\r\nHost: pingora.org\r\n\r\n";
let (mut http_stream, _handle) = session_from_input_no_validate(input).await;
let res = http_stream.read_request().await;
assert!(res.is_ok());
}
// Upgrade requests that also carry a 10-byte request body (POST_BODY_DATA),
// framed by Content-Length or by chunked Transfer-Encoding.
const POST_CL_UPGRADE_REQ: &[u8] = b"POST / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\nContent-Length: 10\r\n\r\n";
const POST_CHUNKED_UPGRADE_REQ: &[u8] = b"POST / HTTP/1.1\r\nHost: pingora.org\r\nUpgrade: websocket\r\nConnection: upgrade\r\nTransfer-Encoding: chunked\r\n\r\n";
const POST_BODY_DATA: &[u8] = b"abcdefghij";
// Same steps as session_from_input, for the upgrade-with-body request headers above.
async fn build_upgrade_req_with_body(header: &[u8]) -> (HttpSession, SubrequestHandle) {
let mock_io = Builder::new().read(header).build();
let mut http_stream = GenericHttpSession::new_http1(Box::new(mock_io));
http_stream.read_request().await.unwrap();
let (mut http_stream, handle) = HttpSession::new_from_session(&http_stream);
http_stream.read_request().await.unwrap();
(http_stream, handle)
}
// Before the upgrade, the request body is read to completion. After the 101,
// the reader is open again (close-delimited): it forwards data and ends only
// once the handle's sender is dropped.
#[rstest]
#[case::content_length(POST_CL_UPGRADE_REQ)]
#[case::chunked(POST_CHUNKED_UPGRADE_REQ)]
#[tokio::test]
async fn read_upgrade_req_with_body(#[case] header: &[u8]) {
init_log();
let (mut http_stream, handle) = build_upgrade_req_with_body(header).await;
assert!(http_stream.is_upgrade_req());
assert!(!http_stream.is_body_done());
handle
.tx
.send(HttpTask::Body(Some(Bytes::from(POST_BODY_DATA)), true))
.await
.unwrap();
let mut buf = vec![];
while let Some(b) = http_stream.read_body_bytes().await.unwrap() {
buf.put_slice(&b);
}
assert_eq!(buf, POST_BODY_DATA);
assert_eq!(http_stream.body_reader.body_state, ParseState::Complete(10));
assert_eq!(http_stream.body_bytes_read(), 10);
assert!(http_stream.is_body_done());
let mut response = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, None).unwrap();
response.set_version(http::Version::HTTP_11);
http_stream
.write_response_header(Box::new(response))
.await
.unwrap();
assert!(!http_stream.is_body_done());
let ws_data = b"data";
handle
.tx
.send(HttpTask::Body(Some(Bytes::from(&ws_data[..])), false))
.await
.unwrap();
let buf = http_stream.read_body_bytes().await.unwrap().unwrap();
assert_eq!(buf, ws_data.as_slice());
assert!(!http_stream.is_body_done());
drop(handle.tx);
assert!(http_stream.read_body_bytes().await.unwrap().is_none());
assert!(http_stream.is_body_done());
}
}