use bytes::Bytes;
use futures::FutureExt;
use h2::client::{self, ResponseFuture, SendRequest};
use h2::{Reason, RecvStream, SendStream};
use http::HeaderMap;
use log::{debug, error, warn};
use pingora_error::{Error, ErrorType, ErrorType::*, OrErr, Result, RetryType};
use pingora_http::{RequestHeader, ResponseHeader};
use pingora_timeout::timeout;
use std::io::ErrorKind;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::watch;
use crate::connectors::http::v2::ConnectionRef;
use crate::protocols::{Digest, SocketAddr, UniqueIDType};
/// Error type that `Http2Session` substitutes for an error's original type when the
/// connection reports that its ping has timed out.
pub const PING_TIMEDOUT: ErrorType = ErrorType::new("PingTimedout");
/// Client-side state for a single HTTP/2 request/response exchange on an h2 connection.
pub struct Http2Session {
/// Handle used to send the request over the h2 connection.
send_req: SendRequest<Bytes>,
/// Request body stream; `None` until the request header has been sent (or after it is taken).
send_body: Option<SendStream<Bytes>>,
/// Pending response future; set when the request header is sent and taken when the
/// response header is read.
resp_fut: Option<ResponseFuture>,
/// The request header that was sent, if any.
req_sent: Option<Box<RequestHeader>>,
/// The response header, once it has been read.
response_header: Option<ResponseHeader>,
/// Stream for reading the response body and trailers; set together with `response_header`.
response_body_reader: Option<RecvStream>,
/// Optional timeout applied to reading the response header, each body read and the trailers.
pub read_timeout: Option<Duration>,
/// Optional timeout passed to the request body write.
pub write_timeout: Option<Duration>,
/// Reference to the connection this session runs on.
pub conn: ConnectionRef,
/// Whether the end of the request stream has been sent.
ended: bool,
/// Running total of response body bytes received (saturating).
body_recv: usize,
}
/// Notifies the connection that this session is gone when the session is dropped.
impl Drop for Http2Session {
fn drop(&mut self) {
// NOTE(review): presumably this returns the stream slot to the connection's
// bookkeeping — confirm against `ConnectionRef::release_stream`.
self.conn.release_stream();
}
}
impl Http2Session {
pub(crate) fn new(send_req: SendRequest<Bytes>, conn: ConnectionRef) -> Self {
Http2Session {
send_req,
send_body: None,
resp_fut: None,
req_sent: None,
response_header: None,
response_body_reader: None,
read_timeout: None,
write_timeout: None,
conn,
ended: false,
body_recv: 0,
}
}
fn sanitize_request_header(req: &mut RequestHeader) -> Result<()> {
req.set_version(http::Version::HTTP_2);
if req.uri.authority().is_some() {
return Ok(());
}
let Some(authority) = req.headers.get(http::header::HOST).map(|v| v.as_bytes()) else {
return Error::e_explain(InvalidHTTPHeader, "no authority header for h2");
};
let uri = http::uri::Builder::new()
.scheme("https") .authority(authority)
.path_and_query(req.uri.path_and_query().as_ref().unwrap().as_str())
.build();
match uri {
Ok(uri) => {
req.set_uri(uri);
Ok(())
}
Err(_) => Error::e_explain(
InvalidHTTPHeader,
format!("invalid authority from host {authority:?}"),
),
}
}
pub fn write_request_header(&mut self, mut req: Box<RequestHeader>, end: bool) -> Result<()> {
if self.req_sent.is_some() {
return Ok(());
}
Self::sanitize_request_header(&mut req)?;
let parts = req.as_owned_parts();
let request = http::Request::from_parts(parts, ());
let (resp_fut, send_body) = self
.send_req
.send_request(request, end)
.or_err(H2Error, "while sending request")
.map_err(|e| self.handle_err(e))?;
self.req_sent = Some(req);
self.send_body = Some(send_body);
self.resp_fut = Some(resp_fut);
self.ended = self.ended || end;
Ok(())
}
/// Write one chunk of the request body; `end` marks the chunk as the last one.
///
/// If the request stream has already ended, the data is dropped with a warning and
/// `Ok(())` is returned.
///
/// # Errors
/// Propagates the error from the underlying body write after passing it through the
/// session's error handling.
///
/// # Panics
/// Panics if called before the request header has been sent.
pub async fn write_request_body(&mut self, data: Bytes, end: bool) -> Result<()> {
    if self.ended {
        warn!("Try to write request body after end of stream, dropping the extra data");
        return Ok(());
    }

    let stream = self
        .send_body
        .as_mut()
        .expect("Try to write request body before sending request header");
    let written = super::write_body(stream, data, end, self.write_timeout).await;
    if let Err(e) = written {
        return Err(self.handle_err(e));
    }

    self.ended |= end;
    Ok(())
}
pub fn finish_request_body(&mut self) -> Result<()> {
if self.ended {
return Ok(());
}
let body_writer = self
.send_body
.as_mut()
.expect("Try to finish request stream before sending request header");
body_writer
.send_data("".into(), true)
.or_err(WriteError, "while writing empty h2 request body")
.map_err(|e| self.handle_err(e))?;
self.ended = true;
Ok(())
}
pub async fn read_response_header(&mut self) -> Result<()> {
if self.response_header.is_some() {
panic!("H2 response header is already read")
}
let Some(resp_fut) = self.resp_fut.take() else {
panic!("Try to take response header, but it is already taken")
};
let res = match self.read_timeout {
Some(t) => timeout(t, resp_fut)
.await
.map_err(|_| Error::explain(ReadTimedout, "while reading h2 response header"))
.map_err(|e| self.handle_err(e))?,
None => resp_fut.await,
};
let (resp, body_reader) = res.map_err(handle_read_header_error)?.into_parts();
self.response_header = Some(resp.into());
self.response_body_reader = Some(body_reader);
Ok(())
}
/// Poll-based variant of reading the response header.
///
/// While the response is pending, the response future is put back so that the next
/// poll can continue. No read timeout is applied here, and h2 errors are returned
/// unmodified.
///
/// # Panics
/// Panics if the response header was already read, or if the response future has
/// already been taken.
#[doc(hidden)]
pub fn poll_read_response_header(
    &mut self,
    cx: &mut Context<'_>,
) -> Poll<Result<(), h2::Error>> {
    assert!(
        self.response_header.is_none(),
        "H2 response header is already read"
    );
    let mut resp_fut = self
        .resp_fut
        .take()
        .expect("Try to take response header, but it is already taken");

    match resp_fut.poll_unpin(cx) {
        Poll::Pending => {
            // Not ready yet: keep the future around for the next poll.
            self.resp_fut = Some(resp_fut);
            Poll::Pending
        }
        Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
        Poll::Ready(Ok(response)) => {
            let (parts, body_reader) = response.into_parts();
            self.response_header = Some(parts.into());
            self.response_body_reader = Some(body_reader);
            Poll::Ready(Ok(()))
        }
    }
}
pub async fn read_response_body(&mut self) -> Result<Option<Bytes>> {
let Some(body_reader) = self.response_body_reader.as_mut() else {
return Ok(None);
};
let fut = body_reader.data();
let res = match self.read_timeout {
Some(t) => timeout(t, fut)
.await
.map_err(|_| Error::explain(ReadTimedout, "while reading h2 response body"))?,
None => fut.await,
};
let body = res
.transpose()
.or_err(ReadError, "while read h2 response body")
.map_err(|mut e| {
if self.conn.ping_timedout() {
e.etype = PING_TIMEDOUT;
}
e
})?;
if let Some(data) = body.as_ref() {
body_reader
.flow_control()
.release_capacity(data.len())
.or_err(ReadError, "while releasing h2 response body capacity")?;
self.body_recv = self.body_recv.saturating_add(data.len());
}
Ok(body)
}
/// Poll-based variant of reading the next chunk of the response body.
///
/// Yields `None` when there is no body reader yet or when the body has no more data.
/// For every chunk yielded, the same number of bytes of flow-control capacity is
/// released and added to the received-bytes counter, so that
/// `body_bytes_received()` stays accurate regardless of which read API is used.
/// No read timeout is applied here, and h2 errors are returned unmodified.
#[doc(hidden)]
pub fn poll_read_response_body(
    &mut self,
    cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, h2::Error>>> {
    let Some(body_reader) = self.response_body_reader.as_mut() else {
        return Poll::Ready(None);
    };
    let data = match ready!(body_reader.poll_data(cx)).transpose() {
        Ok(data) => data,
        Err(err) => return Poll::Ready(Some(Err(err))),
    };
    if let Some(data) = data {
        body_reader.flow_control().release_capacity(data.len())?;
        // Keep the byte accounting consistent with `read_response_body()`.
        self.body_recv = self.body_recv.saturating_add(data.len());
        return Poll::Ready(Some(Ok(data)));
    }
    Poll::Ready(None)
}
/// Whether the response stream has reached its end.
///
/// Always `false` while there is no body reader (the response header was not read).
pub fn response_finished(&self) -> bool {
    match &self.response_body_reader {
        Some(reader) => reader.is_end_stream(),
        None => false,
    }
}
pub fn check_response_end_or_error(&mut self) -> Result<bool> {
let Some(reader) = self.response_body_reader.as_mut() else {
return Ok(false);
};
if !reader.is_end_stream() {
return Ok(false);
}
match tokio::task::unconstrained(reader.data()).now_or_never() {
Some(None) => Ok(true),
Some(Some(Ok(_))) => Error::e_explain(H2Error, "unexpected data after end stream"),
Some(Some(Err(e))) => Error::e_because(H2Error, "while checking end stream", e),
None => {
panic!("data() not ready after end stream")
}
}
}
pub async fn read_trailers(&mut self) -> Result<Option<HeaderMap>> {
let Some(reader) = self.response_body_reader.as_mut() else {
return Ok(None);
};
let fut = reader.trailers();
let res = match self.read_timeout {
Some(t) => timeout(t, fut)
.await
.map_err(|_| Error::explain(ReadTimedout, "while reading h2 trailer"))
.map_err(|e| self.handle_err(e))?,
None => fut.await,
};
match res {
Ok(t) => Ok(t),
Err(e) => {
if (e.is_go_away() || e.is_reset())
&& e.is_remote()
&& e.reason() == Some(Reason::NO_ERROR)
{
Ok(None)
} else {
Err(e)
}
}
}
.or_err(ReadError, "while reading h2 trailers")
}
/// The request header that was sent, or `None` if no header has been sent yet.
pub fn request_header(&self) -> Option<&RequestHeader> {
self.req_sent.as_deref()
}
/// The response header, or `None` if it has not been read yet.
pub fn response_header(&self) -> Option<&ResponseHeader> {
self.response_header.as_ref()
}
pub fn shutdown(&mut self) {
if !self.ended || !self.response_finished() {
if let Some(send_body) = self.send_body.as_mut() {
send_body.send_reset(h2::Reason::INTERNAL_ERROR)
}
}
}
/// Returns a clone of the connection reference this session runs on.
pub(crate) fn conn(&self) -> ConnectionRef {
self.conn.clone()
}
/// Whether the underlying connection reports that its ping has timed out.
pub(crate) fn ping_timedout(&self) -> bool {
self.conn.ping_timedout()
}
/// The digest of the underlying connection; always `Some`.
pub fn digest(&self) -> Option<&Digest> {
Some(self.conn.digest())
}
/// Mutable access to the digest of the underlying connection, when the connection
/// reference can provide it.
// NOTE(review): the conditions under which this is `None` are decided by
// `ConnectionRef::digest_mut` — confirm there.
pub fn digest_mut(&mut self) -> Option<&mut Digest> {
self.conn.digest_mut()
}
/// The peer address recorded in the connection's socket digest, if available.
///
/// Returns `None` when the connection has no socket digest or the digest has no
/// peer address.
pub fn server_addr(&self) -> Option<&SocketAddr> {
    let socket_digest = self.conn.digest().socket_digest.as_ref()?;
    socket_digest.peer_addr()
}
/// The local address recorded in the connection's socket digest, if available.
///
/// Returns `None` when the connection has no socket digest or the digest has no
/// local address.
pub fn client_addr(&self) -> Option<&SocketAddr> {
    let socket_digest = self.conn.digest().socket_digest.as_ref()?;
    socket_digest.local_addr()
}
/// The unique id of the underlying connection.
pub fn fd(&self) -> UniqueIDType {
self.conn.id()
}
/// Total number of response body bytes counted by this session so far.
pub fn body_bytes_received(&self) -> usize {
self.body_recv
}
/// Take ownership of the request body stream, leaving `None` in the session.
///
/// Returns `None` if the request header has not been sent or the stream was already taken.
pub fn take_request_body_writer(&mut self) -> Option<SendStream<Bytes>> {
self.send_body.take()
}
fn handle_err(&self, mut e: Box<Error>) -> Box<Error> {
if self.ping_timedout() {
e.etype = PING_TIMEDOUT;
}
if self.response_header.is_none() {
if let Some(err) = e.root_cause().downcast_ref::<h2::Error>() {
if err.is_go_away()
&& err.is_remote()
&& (err.reason() == Some(h2::Reason::NO_ERROR))
{
e.retry = true.into();
}
}
}
e
}
}
fn handle_read_header_error(e: h2::Error) -> Box<Error> {
if e.is_remote() && (e.reason() == Some(h2::Reason::HTTP_1_1_REQUIRED)) {
let mut err = Error::because(H2Downgrade, "while reading h2 header", e);
err.retry = true.into();
err
} else if e.is_go_away() && e.is_library() && (e.reason() == Some(h2::Reason::PROTOCOL_ERROR)) {
let mut err = Error::because(InvalidH2, "while reading h2 header", e);
err.retry = true.into();
err
} else if e.is_go_away() && e.is_remote() && (e.reason() == Some(h2::Reason::NO_ERROR)) {
let mut err = Error::because(H2Error, "while reading h2 header", e);
err.retry = true.into();
err
} else if e.is_reset() && e.is_remote() && (e.reason() == Some(h2::Reason::REFUSED_STREAM)) {
let mut err = Error::because(H2Error, "while reading h2 header", e);
err.retry = true.into();
err
} else if e.is_io() {
let io_err = e.get_io().expect("checked is io");
let true_io_error = io_err.raw_os_error().is_some()
|| matches!(
io_err.kind(),
ErrorKind::ConnectionReset | ErrorKind::TimedOut | ErrorKind::BrokenPipe
);
let mut err = Error::because(ReadError, "while reading h2 header", e);
if true_io_error {
err.retry = RetryType::ReusedOnly;
} err
} else {
Error::because(H2Error, "while reading h2 header", e)
}
}
use tokio::sync::oneshot;
pub async fn drive_connection<S>(
mut c: client::Connection<S>,
id: UniqueIDType,
closed: watch::Sender<bool>,
ping_interval: Option<Duration>,
ping_timeout_occurred: Arc<AtomicBool>,
) where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
let interval = ping_interval.unwrap_or(Duration::ZERO);
if !interval.is_zero() {
let (tx, rx) = oneshot::channel::<()>();
let dropped = Arc::new(AtomicBool::new(false));
let dropped2 = dropped.clone();
if let Some(ping_pong) = c.ping_pong() {
pingora_runtime::current_handle().spawn(async move {
do_ping_pong(ping_pong, interval, tx, dropped2, id).await;
});
} else {
warn!("Cannot get ping-pong handler from h2 connection");
}
tokio::select! {
r = c => match r {
Ok(_) => debug!("H2 connection finished fd: {id}"),
Err(e) => debug!("H2 connection fd: {id} errored: {e:?}"),
},
r = rx => match r {
Ok(_) => {
ping_timeout_occurred.store(true, Ordering::Relaxed);
warn!("H2 connection Ping timeout/Error fd: {id}, closing conn");
},
Err(e) => warn!("H2 connection Ping Rx error {e:?}"),
},
};
dropped.store(true, Ordering::Relaxed);
} else {
match c.await {
Ok(_) => debug!("H2 connection finished fd: {id}"),
Err(e) => debug!("H2 connection fd: {id} errored: {e:?}"),
}
}
let _ = closed.send(true);
}
/// How long `do_ping_pong` waits for a single ping to complete before treating it as timed out.
const PING_TIMEOUT: Duration = Duration::from_secs(5);
/// Periodically ping the peer until a ping fails or `dropped` is set.
///
/// Sleeps `interval` before the first ping and after every successful pong. Each ping
/// must complete within `PING_TIMEOUT`. On a timeout or a ping error, `tx` is signalled
/// and the task stops; a ping error observed after `dropped` has been set stops the
/// task without signalling.
async fn do_ping_pong(
    mut ping_pong: h2::PingPong,
    interval: Duration,
    tx: oneshot::Sender<()>,
    dropped: Arc<AtomicBool>,
    id: UniqueIDType,
) {
    tokio::time::sleep(interval).await;

    while !dropped.load(Ordering::Relaxed) {
        let ping_fut = ping_pong.ping(h2::Ping::opaque());
        debug!("H2 fd: {id} ping sent");

        match tokio::time::timeout(PING_TIMEOUT, ping_fut).await {
            Ok(Ok(_)) => {
                debug!("H2 fd: {} pong received", id);
                tokio::time::sleep(interval).await;
            }
            Ok(Err(e)) => {
                // If the connection driver already exited, there is nobody to notify.
                if !dropped.load(Ordering::Relaxed) {
                    error!("H2 fd: {id} ping error: {e}");
                    let _ = tx.send(());
                }
                break;
            }
            Err(_) => {
                error!("H2 fd: {id} ping timeout");
                let _ = tx.send(());
                break;
            }
        }
    }
}
// Tests for the HTTP/2 client session, run against an in-memory h2 server.
#[cfg(test)]
mod tests_h2 {
use super::*;
use bytes::Bytes;
use http::{Response, StatusCode};
use tokio::io::duplex;
// The response body arrives in two data frames ("a" then "bc"); the session's byte
// counter must add up to the total of 3 bytes.
#[tokio::test]
async fn h2_body_bytes_received_multi_frames() {
// In-memory pipe connecting the client and the server halves.
let (client_io, server_io) = duplex(65536);
// Server side: accept one GET request and answer 200 with a two-frame body.
tokio::spawn(async move {
let mut conn = h2::server::handshake(server_io).await.unwrap();
if let Some(result) = conn.accept().await {
let (req, mut send_resp) = result.unwrap();
assert_eq!(req.method(), http::Method::GET);
let resp = Response::builder().status(StatusCode::OK).body(()).unwrap();
let mut send_stream = send_resp.send_response(resp, false).unwrap();
send_stream.send_data(Bytes::from("a"), false).unwrap();
send_stream.send_data(Bytes::from("bc"), true).unwrap();
conn.graceful_shutdown();
}
// Keep polling the server connection until it is closed.
while let Some(_res) = conn.accept().await {}
});
// Client side: handshake and drive the connection in the background.
let (send_req, connection) = h2::client::handshake(client_io).await.unwrap();
let (closed_tx, closed_rx) = tokio::sync::watch::channel(false);
let ping_timeout = Arc::new(AtomicBool::new(false));
tokio::spawn(async move {
let _ = connection.await;
let _ = closed_tx.send(true);
});
let digest = Digest::default();
let conn_ref = crate::connectors::http::v2::ConnectionRef::new(
send_req.clone(),
closed_rx,
ping_timeout,
0,
1,
digest,
);
let mut h2s = Http2Session::new(send_req, conn_ref);
// The URI has no authority, so the Host header is what supplies it.
let mut req = RequestHeader::build("GET", b"/", None).unwrap();
req.insert_header(http::header::HOST, "example.com")
.unwrap();
// `true`: the request ends with the header, no request body follows.
h2s.write_request_header(Box::new(req), true).unwrap();
h2s.read_response_header().await.unwrap();
// Drain the body and compare the observed length with the session's counter.
let mut total = 0;
while let Some(chunk) = h2s.read_response_body().await.unwrap() {
total += chunk.len();
}
assert_eq!(total, 3);
assert_eq!(h2s.body_bytes_received(), 3);
}
}