use bytes::Bytes;
use futures::Future;
use h2::server;
use h2::server::SendResponse;
use h2::{RecvStream, SendStream};
use http::header::HeaderName;
use http::uri::PathAndQuery;
use http::{header, HeaderMap, Response};
use log::{debug, warn};
use pingora_http::{RequestHeader, ResponseHeader};
use pingora_timeout::timeout;
use std::sync::Arc;
use std::task::ready;
use std::time::Duration;
use crate::protocols::http::body_buffer::FixedBuffer;
use crate::protocols::http::date::get_cached_date;
use crate::protocols::http::v1::client::http_req_header_to_wire;
use crate::protocols::http::HttpTask;
use crate::protocols::{Digest, SocketAddr, Stream};
use crate::{Error, ErrorType, OrErr, Result};
const BODY_BUF_LIMIT: usize = 1024 * 64;
type H2Connection<S> = server::Connection<S, Bytes>;
pub use h2::server::Builder as H2Options;
pub async fn handshake(io: Stream, options: Option<H2Options>) -> Result<H2Connection<Stream>> {
let options = options.unwrap_or_default();
let res = options.handshake(io).await;
match res {
Ok(connection) => {
debug!("H2 handshake done.");
Ok(connection)
}
Err(e) => Error::e_because(
ErrorType::HandshakeError,
"while h2 handshaking with client",
e,
),
}
}
use futures::task::Context;
use futures::task::Poll;
use std::pin::Pin;
pub struct Idle<'a>(&'a mut HttpSession);
impl Future for Idle<'_> {
type Output = Result<h2::Reason>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(body_writer) = self.0.send_response_body.as_mut() {
body_writer.poll_reset(cx)
} else {
self.0.send_response.poll_reset(cx)
}
.map_err(|e| Error::because(ErrorType::H2Error, "downstream error while idling", e))
}
}
/// Server side of a single HTTP/2 request/response stream.
///
/// One `HttpSession` is created per accepted stream by `HttpSession::from_h2_conn`.
// NOTE: Rust drops struct fields in declaration order; the order below is intentionally
// left unchanged so the h2 stream handles are released in the same sequence.
pub struct HttpSession {
/// Request header, converted from the h2 request parts when the stream is accepted.
request_header: RequestHeader,
/// h2 stream from which the request body is read.
request_body_reader: RecvStream,
/// Handle used to send the response header (and to reset the stream in `shutdown`).
send_response: SendResponse<Bytes>,
/// Response body stream; `Some` once the response header has been sent.
send_response_body: Option<SendStream<Bytes>>,
/// Copy of the response header that was sent, if any.
response_written: Option<Box<ResponseHeader>>,
/// Whether the response has been ended (end-of-stream sent or trailers written).
ended: bool,
/// Number of request body bytes read so far.
body_read: usize,
/// Number of response body bytes written so far.
body_sent: usize,
/// Optional copy of the request body, enabled by `enable_retry_buffering` and
/// capped at `BODY_BUF_LIMIT` bytes.
retry_buffer: Option<FixedBuffer>,
/// Connection digest attached to this session (shared through an `Arc`).
digest: Arc<Digest>,
/// Timeout applied to each response body write; `None` means no timeout.
pub write_timeout: Option<Duration>,
/// Upper bound on the total time spent in `drain_request_body`; `None` means no bound.
total_drain_timeout: Option<Duration>,
}
impl HttpSession {
    /// Accept the next request stream from `conn` and wrap it in an [`HttpSession`].
    ///
    /// `digest` is attached to the new session as-is.
    ///
    /// Returns `Ok(None)` when `conn.accept()` yields no further stream.
    ///
    /// # Errors
    /// [`ErrorType::H2Error`] if accepting the stream fails.
    pub async fn from_h2_conn(
        conn: &mut H2Connection<Stream>,
        digest: Arc<Digest>,
    ) -> Result<Option<Self>> {
        let res = conn.accept().await.transpose().or_err(
            ErrorType::H2Error,
            "while accepting new downstream requests",
        )?;

        Ok(res.map(|(req, send_response)| {
            let (request_header, request_body_reader) = req.into_parts();
            HttpSession {
                request_header: request_header.into(),
                request_body_reader,
                send_response,
                send_response_body: None,
                response_written: None,
                ended: false,
                body_read: 0,
                body_sent: 0,
                retry_buffer: None,
                digest,
                write_timeout: None,
                total_drain_timeout: None,
            }
        }))
    }

    /// The request header of this stream.
    pub fn req_header(&self) -> &RequestHeader {
        &self.request_header
    }

    /// Mutable access to the request header of this stream.
    pub fn req_header_mut(&mut self) -> &mut RequestHeader {
        &mut self.request_header
    }

    /// Read the next chunk of the request body.
    ///
    /// Returns `Ok(None)` at the end of the body. Each chunk is counted in
    /// [`Self::body_bytes_read`], copied into the retry buffer when it is enabled, and its
    /// flow-control capacity is handed back to the peer.
    ///
    /// # Errors
    /// [`ErrorType::ReadError`] if reading from the h2 stream fails.
    pub async fn read_body_bytes(&mut self) -> Result<Option<Bytes>> {
        let data = self.request_body_reader.data().await.transpose().or_err(
            ErrorType::ReadError,
            "while reading downstream request body",
        )?;
        if let Some(data) = data.as_ref() {
            self.body_read += data.len();
            if let Some(buffer) = self.retry_buffer.as_mut() {
                buffer.write_to_buffer(data);
            }
            // Best effort: the chunk was read successfully, so a failure to release its
            // flow-control capacity is deliberately not surfaced here.
            let _ = self
                .request_body_reader
                .flow_control()
                .release_capacity(data.len());
        }
        Ok(data)
    }

    /// Poll-based variant of [`Self::read_body_bytes`].
    ///
    /// Like the async variant, it updates the read counter, copies the chunk into the
    /// retry buffer (when enabled) and releases the chunk's flow-control capacity.
    /// Unlike the async variant, a flow-control release failure is returned as an error.
    #[doc(hidden)]
    pub fn poll_read_body_bytes(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Bytes, h2::Error>>> {
        let data = match ready!(self.request_body_reader.poll_data(cx)).transpose() {
            Ok(data) => data,
            Err(err) => return Poll::Ready(Some(Err(err))),
        };
        let Some(data) = data else {
            return Poll::Ready(None);
        };
        self.body_read += data.len();
        self.request_body_reader
            .flow_control()
            .release_capacity(data.len())?;
        // Keep the retry buffer in sync with the async read path. Otherwise
        // `get_retry_buffer` could return a body that is missing these chunks without
        // being marked as truncated.
        if let Some(buffer) = self.retry_buffer.as_mut() {
            buffer.write_to_buffer(&data);
        }
        Poll::Ready(Some(Ok(data)))
    }

    /// Read and discard request body chunks until the end of the body.
    async fn do_drain_request_body(&mut self) -> Result<()> {
        while self.read_body_bytes().await?.is_some() {}
        Ok(())
    }

    /// Consume (and discard) whatever remains of the request body.
    ///
    /// Returns immediately if [`Self::is_body_done`]. When a total drain timeout is set,
    /// the whole drain is bounded by it.
    ///
    /// # Errors
    /// [`ErrorType::ReadTimedout`] if the drain timeout elapses, or any read error from
    /// [`Self::read_body_bytes`].
    pub async fn drain_request_body(&mut self) -> Result<()> {
        if self.is_body_done() {
            return Ok(());
        }

        match self.total_drain_timeout {
            Some(t) => match timeout(t, self.do_drain_request_body()).await {
                Ok(res) => res,
                Err(_) => Error::e_explain(
                    ErrorType::ReadTimedout,
                    format!("draining body, timeout: {t:?}"),
                ),
            },
            None => self.do_drain_request_body().await,
        }
    }

    /// Set the timeout applied to each response body write (`None` disables it).
    pub fn set_write_timeout(&mut self, timeout: Option<Duration>) {
        self.write_timeout = timeout;
    }

    /// The timeout applied to each response body write, if any.
    pub fn get_write_timeout(&self) -> Option<Duration> {
        self.write_timeout
    }

    /// Set the upper bound on the total time spent in [`Self::drain_request_body`].
    pub fn set_total_drain_timeout(&mut self, timeout: Option<Duration>) {
        self.total_drain_timeout = timeout;
    }

    /// The upper bound on the total time spent in [`Self::drain_request_body`], if any.
    pub fn get_total_drain_timeout(&self) -> Option<Duration> {
        self.total_drain_timeout
    }

    /// Send the response header; `end` also ends the response stream.
    ///
    /// The call is a no-op returning `Ok(())` in three cases: the response already ended,
    /// the header is informational (1xx), or a header was already sent. Before sending, a
    /// `Date` header is inserted and HTTP/1 connection-specific headers, which HTTP/2
    /// forbids, are removed.
    ///
    /// # Errors
    /// Fails if inserting the `Date` header fails, or with [`ErrorType::WriteError`] if
    /// h2 rejects the response.
    pub fn write_response_header(
        &mut self,
        mut header: Box<ResponseHeader>,
        end: bool,
    ) -> Result<()> {
        if self.ended {
            return Ok(());
        }

        if header.status.is_informational() {
            debug!("ignoring informational headers");
            return Ok(());
        }

        if self.response_written.as_ref().is_some() {
            warn!("Response header is already sent, cannot send again");
            return Ok(());
        }

        header.insert_header(header::DATE, get_cached_date())?;

        // Connection-specific header fields are not allowed in HTTP/2.
        header.remove_header(&header::TRANSFER_ENCODING);
        header.remove_header(&header::CONNECTION);
        header.remove_header(&header::UPGRADE);
        header.remove_header(&HeaderName::from_static("keep-alive"));
        header.remove_header(&HeaderName::from_static("proxy-connection"));

        let resp = Response::from_parts(header.as_owned_parts(), ());
        let body_writer = self.send_response.send_response(resp, end).or_err(
            ErrorType::WriteError,
            "while writing h2 response to downstream",
        )?;

        self.response_written = Some(header);
        self.send_response_body = Some(body_writer);
        self.ended = self.ended || end;
        Ok(())
    }

    /// Write a chunk of the response body; `end` also ends the response stream.
    ///
    /// When a write timeout is set, the whole write is bounded by it.
    ///
    /// # Errors
    /// [`ErrorType::WriteTimedout`] on timeout, [`ErrorType::H2Error`] if no header was
    /// sent yet, or the underlying write error (marked as downstream).
    pub async fn write_body(&mut self, data: Bytes, end: bool) -> Result<()> {
        match self.write_timeout {
            Some(t) => match timeout(t, self.do_write_body(data, end)).await {
                Ok(res) => res,
                Err(_) => Error::e_explain(
                    ErrorType::WriteTimedout,
                    format!("writing body, timeout: {t:?}"),
                ),
            },
            None => self.do_write_body(data, end).await,
        }
    }

    /// Untimed body write used by [`Self::write_body`]. Data written after the stream
    /// ended is dropped with a warning.
    async fn do_write_body(&mut self, data: Bytes, end: bool) -> Result<()> {
        if self.ended {
            warn!("Try to write body after end of stream, dropping the extra data");
            return Ok(());
        }

        let Some(writer) = self.send_response_body.as_mut() else {
            return Err(Error::explain(
                ErrorType::H2Error,
                "try to send body before header is sent",
            ));
        };
        let data_len = data.len();
        // NOTE(review): `super::write_body` also receives the write timeout, in addition
        // to the outer `timeout` in `write_body`.
        super::write_body(writer, data, end, self.write_timeout)
            .await
            .map_err(|e| e.into_down())?;
        self.body_sent += data_len;
        self.ended = self.ended || end;
        Ok(())
    }

    /// Send response trailers, which also ends the response stream.
    ///
    /// Trailers written after the stream ended are dropped with a warning.
    ///
    /// # Errors
    /// [`ErrorType::H2Error`] if no header was sent yet, [`ErrorType::WriteError`] if h2
    /// rejects the trailers.
    pub fn write_trailers(&mut self, trailers: HeaderMap) -> Result<()> {
        if self.ended {
            warn!("Tried to write trailers after end of stream, dropping them");
            return Ok(());
        }
        let Some(writer) = self.send_response_body.as_mut() else {
            return Err(Error::explain(
                ErrorType::H2Error,
                "try to send trailers before header is sent",
            ));
        };
        writer.send_trailers(trailers).or_err(
            ErrorType::WriteError,
            "while writing h2 response trailers to downstream",
        )?;
        self.ended = true;
        Ok(())
    }

    /// Same as [`Self::write_response_header`] but clones the header from a reference.
    pub fn write_response_header_ref(&mut self, header: &ResponseHeader, end: bool) -> Result<()> {
        self.write_response_header(Box::new(header.clone()), end)
    }

    /// End the response stream with an empty data frame if it has not ended yet.
    ///
    /// This does nothing if no response header has been sent, because there is no body
    /// stream to end.
    pub fn finish(&mut self) -> Result<()> {
        if self.ended {
            return Ok(());
        }
        if let Some(writer) = self.send_response_body.as_mut() {
            writer.send_data("".into(), true).or_err(
                ErrorType::WriteError,
                "while writing h2 response body to downstream",
            )?;
            self.ended = true;
        }
        Ok(())
    }

    /// Apply a batch of response tasks in order.
    ///
    /// Returns whether the response stream has ended. When it has, [`Self::finish`] is
    /// called to make sure end-of-stream is actually sent. Write errors are marked as
    /// downstream errors.
    ///
    /// # Errors
    /// Returns `InternalError` for upgraded bodies (unsupported on h2), the error carried
    /// by `HttpTask::Failed`, or any write error.
    pub async fn response_duplex_vec(&mut self, tasks: Vec<HttpTask>) -> Result<bool> {
        let mut end_stream = false;
        for task in tasks {
            let task_ends = match task {
                HttpTask::Header(header, end) => {
                    self.write_response_header(header, end)
                        .map_err(|e| e.into_down())?;
                    end
                }
                HttpTask::Body(data, end) => {
                    if let Some(d) = data.filter(|d| !d.is_empty()) {
                        self.write_body(d, end).await.map_err(|e| e.into_down())?;
                    }
                    end
                }
                HttpTask::UpgradedBody(..) => {
                    return Error::e_explain(
                        ErrorType::InternalError,
                        "upgraded body on h2 server session",
                    );
                }
                HttpTask::Trailer(Some(trailers)) => {
                    self.write_trailers(*trailers)
                        .map_err(|e| e.into_down())?;
                    true
                }
                HttpTask::Trailer(None) => true,
                HttpTask::Done => true,
                HttpTask::Failed(e) => {
                    return Err(e);
                }
            };
            // Sticky: a later task with `end == false` must not undo an earlier end.
            end_stream = task_ends || end_stream;
        }
        if end_stream {
            self.finish().map_err(|e| e.into_down())?;
        }
        Ok(end_stream)
    }

    /// One-line description of the request: `"<method> <path?query>, Host: <host>:<port>"`.
    ///
    /// Missing URI components are rendered as empty strings.
    pub fn request_summary(&self) -> String {
        let uri = &self.request_header.uri;
        format!(
            "{} {}, Host: {}:{}",
            self.request_header.method,
            uri.path_and_query()
                .map(PathAndQuery::as_str)
                .unwrap_or_default(),
            uri.host().unwrap_or_default(),
            uri.port()
                .as_ref()
                .map(|port| port.as_str())
                .unwrap_or_default()
        )
    }

    /// The response header that was sent, if any.
    pub fn response_written(&self) -> Option<&ResponseHeader> {
        self.response_written.as_deref()
    }

    /// Reset the stream with `INTERNAL_ERROR` if the response has not ended yet.
    pub fn shutdown(&mut self) {
        if !self.ended {
            self.send_response.send_reset(h2::Reason::INTERNAL_ERROR);
        }
    }

    /// Take ownership of the response body stream, leaving `None` in its place.
    #[doc(hidden)]
    pub fn take_response_body_writer(&mut self) -> Option<SendStream<Bytes>> {
        self.send_response_body.take()
    }

    /// Serialize the request header in HTTP/1 wire format.
    ///
    /// # Panics
    /// Panics if `http_req_header_to_wire` returns `None` (presumably for HTTP versions it
    /// cannot serialize; TODO confirm).
    pub fn pseudo_raw_h1_request_header(&self) -> Bytes {
        let buf = http_req_header_to_wire(&self.request_header).unwrap();
        buf.freeze()
    }

    /// Whether the request body is finished: it is known to be empty, or the h2 stream
    /// reports end-of-stream.
    pub fn is_body_done(&self) -> bool {
        self.is_body_empty() || self.request_body_reader.is_end_stream()
    }

    /// Whether the request body is empty.
    ///
    /// The body counts as empty when nothing has been read yet and either the stream
    /// already ended or the request carries `Content-Length: 0`.
    pub fn is_body_empty(&self) -> bool {
        self.body_read == 0
            && (self.request_body_reader.is_end_stream()
                || self
                    .request_header
                    .headers
                    .get(header::CONTENT_LENGTH)
                    .is_some_and(|cl| cl.as_bytes() == b"0"))
    }

    /// Whether retry buffering is enabled and the buffer overflowed its limit.
    pub fn retry_buffer_truncated(&self) -> bool {
        self.retry_buffer
            .as_ref()
            .is_some_and(|r| r.is_truncated())
    }

    /// Start copying request body chunks into a retry buffer of up to `BODY_BUF_LIMIT`
    /// bytes. Has no effect if buffering is already enabled.
    pub fn enable_retry_buffering(&mut self) {
        if self.retry_buffer.is_none() {
            self.retry_buffer = Some(FixedBuffer::new(BODY_BUF_LIMIT))
        }
    }

    /// The buffered request body, or `None` when buffering is disabled, the buffer was
    /// truncated, or the buffer itself reports no content.
    pub fn get_retry_buffer(&self) -> Option<Bytes> {
        self.retry_buffer
            .as_ref()
            .filter(|b| !b.is_truncated())
            .and_then(|b| b.get_buffer())
    }

    /// A future that resolves when the peer resets this stream (see [`Idle`]).
    pub fn idle(&mut self) -> Idle<'_> {
        Idle(self)
    }

    /// Read the next body chunk, or idle when no more body is expected.
    ///
    /// If `no_body_expected` is set or the body is done, this waits on [`Self::idle`] and
    /// then always returns an error: either the idle error, or an
    /// [`ErrorType::H2Error`] carrying the reset reason.
    pub async fn read_body_or_idle(&mut self, no_body_expected: bool) -> Result<Option<Bytes>> {
        if no_body_expected || self.is_body_done() {
            let reason = self.idle().await?;
            Error::e_explain(
                ErrorType::H2Error,
                format!("Client closed H2, reason: {reason}"),
            )
        } else {
            self.read_body_bytes().await
        }
    }

    /// Total number of response body bytes written.
    pub fn body_bytes_sent(&self) -> usize {
        self.body_sent
    }

    /// Total number of request body bytes read.
    pub fn body_bytes_read(&self) -> usize {
        self.body_read
    }

    /// The connection digest (always `Some`).
    pub fn digest(&self) -> Option<&Digest> {
        Some(&self.digest)
    }

    /// Mutable access to the digest; `Some` only if this session holds the only
    /// reference to it.
    pub fn digest_mut(&mut self) -> Option<&mut Digest> {
        Arc::get_mut(&mut self.digest)
    }

    /// Local address of the underlying socket, if known.
    pub fn server_addr(&self) -> Option<&SocketAddr> {
        self.digest
            .socket_digest
            .as_ref()
            .and_then(|d| d.local_addr())
    }

    /// Peer address of the underlying socket, if known.
    pub fn client_addr(&self) -> Option<&SocketAddr> {
        self.digest
            .socket_digest
            .as_ref()
            .and_then(|d| d.peer_addr())
    }
}
// Integration-style tests that drive a real h2 client against `handshake` /
// `HttpSession` over an in-memory duplex pipe.
#[cfg(test)]
mod test {
use super::*;
use http::{HeaderValue, Method, Request};
use tokio::io::duplex;
// Full round trip: the client sends a GET with a 16-byte body and EOS. The server
// reads the body, checks retry buffering and idling, then answers with a header, a
// body and trailers.
#[tokio::test]
async fn test_server_handshake_accept_request() {
let (client, server) = duplex(65536);
let client_body = "test client body";
let server_body = "test server body";
let mut expected_trailers = HeaderMap::new();
expected_trailers.insert("test", HeaderValue::from_static("trailers"));
let trailers = expected_trailers.clone();
let mut handles = vec![];
// Client side: send the request and verify status, body and trailers of the response.
handles.push(tokio::spawn(async move {
let (h2, connection) = h2::client::handshake(client).await.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
let mut h2 = h2.ready().await.unwrap();
let request = Request::builder()
.method(Method::GET)
.uri("https://www.example.com/")
.body(())
.unwrap();
let (response, mut req_body) = h2.send_request(request, false).unwrap();
req_body.reserve_capacity(client_body.len());
req_body.send_data(client_body.into(), true).unwrap();
let (head, mut body) = response.await.unwrap().into_parts();
assert_eq!(head.status, 200);
let data = body.data().await.unwrap().unwrap();
assert_eq!(data, server_body);
let resp_trailers = body.trailers().await.unwrap().unwrap();
assert_eq!(resp_trailers, expected_trailers);
}));
let mut connection = handshake(Box::new(server), None).await.unwrap();
let digest = Arc::new(Digest::default());
// Server side: handle every accepted stream in its own task.
while let Some(mut http) = HttpSession::from_h2_conn(&mut connection, digest.clone())
.await
.unwrap()
{
let trailers = trailers.clone();
handles.push(tokio::spawn(async move {
let req = http.req_header();
assert_eq!(req.method, Method::GET);
assert_eq!(req.uri, "https://www.example.com/");
http.enable_retry_buffering();
assert!(!http.is_body_empty());
assert!(!http.is_body_done());
let body = http.read_body_or_idle(false).await.unwrap().unwrap();
assert_eq!(body, client_body);
assert!(http.is_body_done());
assert_eq!(http.body_bytes_read(), 16);
let retry_body = http.get_retry_buffer().unwrap();
assert_eq!(retry_body, client_body);
// The client has not reset the stream, so `idle()` must not resolve within 1s.
tokio::select! {
_ = http.idle() => {panic!("downstream should be idling")},
_= tokio::time::sleep(tokio::time::Duration::from_secs(1)) => {}
}
let response_header = Box::new(ResponseHeader::build(200, None).unwrap());
assert!(http
.write_response_header(response_header.clone(), false)
.is_ok());
// A second header is ignored with a warning, but still reports Ok.
assert!(http.write_response_header(response_header, false).is_ok());
// The body is done, so read_body_or_idle idles instead of returning data.
tokio::select! {
_ = http.read_body_or_idle(false) => {panic!("downstream should be idling")},
_= tokio::time::sleep(tokio::time::Duration::from_secs(1)) => {}
}
http.write_body(server_body.into(), false).await.unwrap();
assert_eq!(http.body_bytes_sent(), 16);
http.write_trailers(trailers).unwrap();
// Trailers already ended the stream, so finish() is a no-op here.
http.finish().unwrap();
}));
}
for handle in handles {
assert!(handle.await.is_ok());
}
}
// The request carries `content-length: 0` but its HEADERS frame does not set EOS. The
// body must still be treated as empty/done purely because of the content-length.
#[tokio::test]
async fn test_req_content_length_eq_0_and_no_header_eos() {
let (client, server) = duplex(65536);
let server_body = "test server body";
let mut handles = vec![];
handles.push(tokio::spawn(async move {
let (h2, connection) = h2::client::handshake(client).await.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
let mut h2 = h2.ready().await.unwrap();
let request = Request::builder()
.method(Method::POST)
.uri("https://www.example.com/")
.header("content-length", "0") .body(())
.unwrap();
let (response, mut req_body) = h2.send_request(request, false).unwrap();
let (head, mut body) = response.await.unwrap().into_parts();
assert_eq!(head.status, 200);
let data = body.data().await.unwrap().unwrap();
assert_eq!(data, server_body);
// EOS on the request is only sent after the response body has been received.
req_body.send_data("".into(), true).unwrap(); }));
let mut connection = handshake(Box::new(server), None).await.unwrap();
let digest = Arc::new(Digest::default());
while let Some(mut http) = HttpSession::from_h2_conn(&mut connection, digest.clone())
.await
.unwrap()
{
handles.push(tokio::spawn(async move {
let req = http.req_header();
assert_eq!(req.method, Method::POST);
assert_eq!(req.uri, "https://www.example.com/");
http.enable_retry_buffering();
assert!(http.is_body_empty());
assert!(http.is_body_done());
// Nothing was read, so the retry buffer yields no content.
let retry_body = http.get_retry_buffer();
assert!(retry_body.is_none());
let response_header = Box::new(ResponseHeader::build(200, None).unwrap());
assert!(http
.write_response_header(response_header.clone(), false)
.is_ok());
http.write_body(server_body.into(), false).await.unwrap();
assert_eq!(http.body_bytes_sent(), 16);
// The body is done, so this idles; read_body_or_idle returns Err on every idle
// outcome once the client side goes away.
assert!(http.read_body_or_idle(http.is_body_done()).await.is_err());
}));
}
for handle in handles {
assert!(handle.await.is_ok());
}
}
// The request has neither content-length nor EOS on its HEADERS frame. The body is
// not known to be empty until the client later sends an empty DATA frame with EOS.
#[tokio::test]
async fn test_req_header_no_eos_empty_data_with_eos() {
let (client, server) = duplex(65536);
let server_body = "test server body";
let mut handles = vec![];
handles.push(tokio::spawn(async move {
let (h2, connection) = h2::client::handshake(client).await.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
let mut h2 = h2.ready().await.unwrap();
let request = Request::builder()
.method(Method::POST)
.uri("https://www.example.com/")
.body(())
.unwrap();
let (response, mut req_body) = h2.send_request(request, false).unwrap();
let (head, mut body) = response.await.unwrap().into_parts();
assert_eq!(head.status, 200);
let data = body.data().await.unwrap().unwrap();
assert_eq!(data, server_body);
// Empty DATA frame carrying EOS, sent after the response body arrives.
req_body.send_data("".into(), true).unwrap(); }));
let mut connection = handshake(Box::new(server), None).await.unwrap();
let digest = Arc::new(Digest::default());
while let Some(mut http) = HttpSession::from_h2_conn(&mut connection, digest.clone())
.await
.unwrap()
{
handles.push(tokio::spawn(async move {
let req = http.req_header();
assert_eq!(req.method, Method::POST);
assert_eq!(req.uri, "https://www.example.com/");
http.enable_retry_buffering();
assert!(!http.is_body_empty());
assert!(!http.is_body_done());
let retry_body = http.get_retry_buffer();
assert!(retry_body.is_none());
let response_header = Box::new(ResponseHeader::build(200, None).unwrap());
assert!(http
.write_response_header(response_header.clone(), false)
.is_ok());
http.write_body(server_body.into(), false).await.unwrap();
assert_eq!(http.body_bytes_sent(), 16);
// The body is not done, so this reads from the stream; the test only requires the
// read itself to succeed (its chunk/None result is not asserted).
http.read_body_or_idle(http.is_body_done()).await.unwrap();
}));
}
for handle in handles {
assert!(handle.await.is_ok());
}
}
}