use bytes::Bytes;
use h2::client::{self, ResponseFuture, SendRequest};
use h2::{Reason, RecvStream, SendStream};
use http::HeaderMap;
use log::{debug, error, warn};
use pingora_error::{Error, ErrorType, ErrorType::*, OrErr, Result, RetryType};
use pingora_http::{RequestHeader, ResponseHeader};
use pingora_timeout::timeout;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::watch;
use crate::connectors::http::v2::ConnectionRef;
use crate::protocols::{Digest, SocketAddr, UniqueIDType};
/// Error type given to failures on a connection that reports a ping timeout
/// (`ConnectionRef::ping_timedout`), so callers can tell them apart from ordinary I/O errors.
pub const PING_TIMEDOUT: ErrorType = ErrorType::new("PingTimedout");
/// Client-side state of a single request/response exchange (one stream) on an HTTP/2 connection.
pub struct Http2Session {
    /// Handle used to open the request stream.
    send_req: SendRequest<Bytes>,
    /// Writer for the request body; set once the request header has been sent.
    send_body: Option<SendStream<Bytes>>,
    /// Pending response; set when the request header is sent and taken when the header is read.
    resp_fut: Option<ResponseFuture>,
    /// The request header that was sent, if any.
    req_sent: Option<Box<RequestHeader>>,
    /// The response header, once it has been read.
    response_header: Option<ResponseHeader>,
    /// Reader for the response body and trailers, available after the response header is read.
    response_body_reader: Option<RecvStream>,
    /// Timeout applied to each response read (header, body chunk, trailers); `None` waits forever.
    pub read_timeout: Option<Duration>,
    /// The shared h2 connection this session's stream belongs to.
    pub(crate) conn: ConnectionRef,
    /// Whether the end of the request stream has already been signalled.
    ended: bool,
}
impl Drop for Http2Session {
    /// Hand this session's stream back to the shared connection via
    /// `ConnectionRef::release_stream` when the session goes away.
    fn drop(&mut self) {
        self.conn.release_stream();
    }
}
impl Http2Session {
    /// Create a session that will send one request over the h2 connection behind `conn`.
    ///
    /// `send_req` is the handle the request stream is opened with; `conn` is the shared
    /// connection reference this session's stream belongs to.
    pub(crate) fn new(send_req: SendRequest<Bytes>, conn: ConnectionRef) -> Self {
        Http2Session {
            send_req,
            send_body: None,
            resp_fut: None,
            req_sent: None,
            response_header: None,
            response_body_reader: None,
            read_timeout: None,
            conn,
            ended: false,
        }
    }

    /// Adapt `req` for h2: set the version to HTTP/2 and make sure the URI carries an
    /// authority, which h2 needs for the `:authority` pseudo header.
    ///
    /// When the URI has no authority, a new URI is built from the `Host` header, a fixed
    /// `https` scheme and the request's original path and query.
    ///
    /// # Errors
    /// `InvalidHTTPHeader` if there is no `Host` header or its value is not a valid authority.
    fn sanitize_request_header(req: &mut RequestHeader) -> Result<()> {
        req.set_version(http::Version::HTTP_2);
        if req.uri.authority().is_some() {
            return Ok(());
        }
        // use the host header to populate the :authority field
        let Some(authority) = req.headers.get(http::header::HOST).map(|v| v.as_bytes()) else {
            return Error::e_explain(InvalidHTTPHeader, "no authority header for h2");
        };
        // A URI without an authority is expected to always report a path; fall back to "/"
        // rather than panicking if that ever does not hold.
        let path_and_query = req.uri.path_and_query().map_or("/", |pq| pq.as_str());
        let uri = http::uri::Builder::new()
            .scheme("https") // fixed for now
            .authority(authority)
            .path_and_query(path_and_query)
            .build();
        match uri {
            Ok(uri) => {
                req.set_uri(uri);
                Ok(())
            }
            Err(_) => Error::e_explain(
                InvalidHTTPHeader,
                format!("invalid authority from host {authority:?}"),
            ),
        }
    }

    /// Send the request header, opening the stream for this session.
    ///
    /// Calling this again after a header has been sent is a no-op. `end` signals that the
    /// request has no body.
    ///
    /// # Errors
    /// Errors from [`Self::sanitize_request_header`], or `H2Error` if h2 refuses to open the
    /// stream (post-processed by `handle_err`, which may mark it retryable or as a ping timeout).
    pub fn write_request_header(&mut self, mut req: Box<RequestHeader>, end: bool) -> Result<()> {
        if self.req_sent.is_some() {
            // cannot send again
            return Ok(());
        }
        Self::sanitize_request_header(&mut req)?;
        let parts = req.as_owned_parts();
        let request = http::Request::from_parts(parts, ());
        let (resp_fut, send_body) = self
            .send_req
            .send_request(request, end)
            .or_err(H2Error, "while sending request")
            .map_err(|e| self.handle_err(e))?;
        self.req_sent = Some(req);
        self.send_body = Some(send_body);
        self.resp_fut = Some(resp_fut);
        self.ended |= end;
        Ok(())
    }

    /// Write a chunk of request body; `end` marks it as the last one.
    ///
    /// Data written after the request stream was ended is dropped with a warning.
    ///
    /// # Panics
    /// If the request header has not been sent (or the body writer was taken).
    ///
    /// # Errors
    /// `WriteError` if h2 refuses the data.
    pub fn write_request_body(&mut self, data: Bytes, end: bool) -> Result<()> {
        if self.ended {
            warn!("Try to write request body after end of stream, dropping the extra data");
            return Ok(());
        }
        let body_writer = self
            .send_body
            .as_mut()
            .expect("Try to write request body before sending request header");
        write_body(body_writer, data, end).map_err(|e| self.handle_err(e))?;
        self.ended |= end;
        Ok(())
    }

    /// End the request stream (by sending an empty final DATA frame) if not already ended.
    ///
    /// # Panics
    /// If the request header has not been sent (or the body writer was taken).
    ///
    /// # Errors
    /// `WriteError` if h2 refuses the final frame.
    pub fn finish_request_body(&mut self) -> Result<()> {
        if self.ended {
            return Ok(());
        }
        let body_writer = self
            .send_body
            .as_mut()
            .expect("Try to finish request stream before sending request header");
        // Just send an empty data frame with the end of stream flag set
        body_writer
            .send_data("".into(), true)
            .or_err(WriteError, "while writing empty h2 request body")
            .map_err(|e| self.handle_err(e))?;
        self.ended = true;
        Ok(())
    }

    /// Wait for the response header, honouring `read_timeout`.
    ///
    /// On success the header and the body reader are stored on the session.
    ///
    /// # Panics
    /// If the response header was already read, or if there is no pending response (the
    /// request header was never sent, or an earlier call already consumed it).
    ///
    /// # Errors
    /// `ReadTimedout` on timeout, or an error classified by `handle_read_header_error`. Either
    /// way the error is re-tagged as [`PING_TIMEDOUT`] if the connection reports a ping timeout.
    pub async fn read_response_header(&mut self) -> Result<()> {
        if self.response_header.is_some() {
            panic!("H2 response header is already read")
        }
        let Some(resp_fut) = self.resp_fut.take() else {
            panic!("Try to read response header before sending request header or after it was consumed")
        };
        let res = match self.read_timeout {
            Some(t) => timeout(t, resp_fut)
                .await
                .map_err(|_| Error::explain(ReadTimedout, "while reading h2 response header"))
                .map_err(|e| self.handle_err(e))?,
            None => resp_fut.await,
        };
        let (resp, body_reader) = res
            .map_err(handle_read_header_error)
            // tag failures caused by a ping timeout, like every other read path does
            .map_err(|e| self.handle_err(e))?
            .into_parts();
        self.response_header = Some(resp.into());
        self.response_body_reader = Some(body_reader);
        Ok(())
    }

    /// Read the next chunk of response body, honouring `read_timeout`.
    ///
    /// Returns `Ok(None)` if the header has not been read yet or the body is finished. Each
    /// chunk's size is released back to h2 flow control before it is returned.
    ///
    /// # Errors
    /// `ReadTimedout` on timeout, `ReadError` on stream errors; both re-tagged as
    /// [`PING_TIMEDOUT`] if the connection reports a ping timeout.
    pub async fn read_response_body(&mut self) -> Result<Option<Bytes>> {
        let Some(body_reader) = self.response_body_reader.as_mut() else {
            return Ok(None);
        };
        if body_reader.is_end_stream() {
            return Ok(None);
        }
        let fut = body_reader.data();
        let res = match self.read_timeout {
            Some(t) => timeout(t, fut)
                .await
                .map_err(|_| Error::explain(ReadTimedout, "while reading h2 response body"))
                // `self.handle_err()` would borrow all of `self` while `body_reader` still
                // borrows `self.response_body_reader`, so borrow only `self.conn` here
                .map_err(|e| Self::mark_ping_timeout(&self.conn, e))?,
            None => fut.await,
        };
        let body = res
            .transpose()
            .or_err(ReadError, "while read h2 response body")
            .map_err(|e| Self::mark_ping_timeout(&self.conn, e))?;
        if let Some(data) = body.as_ref() {
            body_reader
                .flow_control()
                .release_capacity(data.len())
                .or_err(ReadError, "while releasing h2 response body capacity")?;
        }
        Ok(body)
    }

    /// Whether the response body has been read to its end (`false` before the header is read).
    pub fn response_finished(&self) -> bool {
        self.response_body_reader
            .as_ref()
            .map_or(false, |reader| reader.is_end_stream())
    }

    /// Read the response trailers, honouring `read_timeout`.
    ///
    /// Returns `Ok(None)` when the header has not been read, when there are no trailers, or
    /// when the peer ended the stream with GOAWAY/RST_STREAM carrying `NO_ERROR`.
    ///
    /// # Errors
    /// `ReadTimedout` on timeout, `ReadError` otherwise; both re-tagged as [`PING_TIMEDOUT`]
    /// if the connection reports a ping timeout.
    pub async fn read_trailers(&mut self) -> Result<Option<HeaderMap>> {
        let Some(reader) = self.response_body_reader.as_mut() else {
            return Ok(None);
        };
        let fut = reader.trailers();
        let res = match self.read_timeout {
            Some(t) => timeout(t, fut)
                .await
                .map_err(|_| Error::explain(ReadTimedout, "while reading h2 trailer"))
                .map_err(|e| self.handle_err(e))?,
            None => fut.await,
        };
        let trailers = match res {
            Ok(t) => Ok(t),
            Err(e) => {
                // a graceful close by the peer just means there are no trailers
                if (e.is_go_away() || e.is_reset())
                    && e.is_remote()
                    && e.reason() == Some(Reason::NO_ERROR)
                {
                    Ok(None)
                } else {
                    Err(e)
                }
            }
        };
        trailers
            .or_err(ReadError, "while reading h2 trailers")
            .map_err(|e| self.handle_err(e))
    }

    /// The response header, if it has been read.
    pub fn response_header(&self) -> Option<&ResponseHeader> {
        self.response_header.as_ref()
    }

    /// Abort the stream with `INTERNAL_ERROR` unless both the request and the response body
    /// are complete.
    pub fn shutdown(&mut self) {
        if !self.ended || !self.response_finished() {
            if let Some(send_body) = self.send_body.as_mut() {
                send_body.send_reset(h2::Reason::INTERNAL_ERROR)
            }
        }
    }

    /// A clone of the shared connection reference.
    pub(crate) fn conn(&self) -> ConnectionRef {
        self.conn.clone()
    }

    /// Whether the underlying connection reports a ping timeout.
    pub(crate) fn ping_timedout(&self) -> bool {
        self.conn.ping_timedout()
    }

    /// The connection digest; always `Some`.
    pub fn digest(&self) -> Option<&Digest> {
        Some(self.conn.digest())
    }

    /// Mutable access to the connection digest, as provided by `ConnectionRef::digest_mut`.
    pub fn digest_mut(&mut self) -> Option<&mut Digest> {
        self.conn.digest_mut()
    }

    /// The peer (server) address from the socket digest, if known.
    pub fn server_addr(&self) -> Option<&SocketAddr> {
        self.conn
            .digest()
            .socket_digest
            .as_ref()
            .and_then(|d| d.peer_addr())
    }

    /// The local (client) address from the socket digest, if known.
    pub fn client_addr(&self) -> Option<&SocketAddr> {
        self.conn
            .digest()
            .socket_digest
            .as_ref()
            .and_then(|d| d.local_addr())
    }

    /// The unique id of the underlying connection.
    pub fn fd(&self) -> UniqueIDType {
        self.conn.id()
    }

    /// Take the request body writer out of the session. Afterwards the session's own body
    /// writing methods panic because the writer is gone.
    pub fn take_request_body_writer(&mut self) -> Option<SendStream<Bytes>> {
        self.send_body.take()
    }

    /// Re-tag `e` as [`PING_TIMEDOUT`] when `conn` reports a ping timeout.
    ///
    /// Takes the connection instead of `&self` so it can be used while another field of the
    /// session is mutably borrowed.
    fn mark_ping_timeout(conn: &ConnectionRef, mut e: Box<Error>) -> Box<Error> {
        if conn.ping_timedout() {
            e.etype = PING_TIMEDOUT;
        }
        e
    }

    /// Post-process an error from an h2 operation on this session.
    ///
    /// Tags it as [`PING_TIMEDOUT`] if the connection reports a ping timeout. While no
    /// response header has been received yet, it also marks the error retryable when its
    /// root cause is a GOAWAY with `NO_ERROR` sent by the peer.
    fn handle_err(&self, e: Box<Error>) -> Box<Error> {
        let mut e = Self::mark_ping_timeout(&self.conn, e);
        if self.response_header.is_none() {
            if let Some(err) = e.root_cause().downcast_ref::<h2::Error>() {
                if err.is_go_away()
                    && err.is_remote()
                    && err.reason().map_or(false, |r| r == h2::Reason::NO_ERROR)
                {
                    e.retry = true.into();
                }
            }
        }
        e
    }
}
/// Queue `data` on the request stream `send_body`; `end` closes the request side afterwards.
///
/// Send capacity for the whole chunk is requested from the stream before the data is passed
/// to `send_data`.
///
/// # Errors
/// `WriteError` if h2 refuses the data.
pub fn write_body(send_body: &mut SendStream<Bytes>, data: Bytes, end: bool) -> Result<()> {
    send_body.reserve_capacity(data.len());
    let outcome = send_body.send_data(data, end);
    outcome.or_err(WriteError, "while writing h2 request body")
}
fn handle_read_header_error(e: h2::Error) -> Box<Error> {
if e.is_remote()
&& e.reason()
.map_or(false, |r| r == h2::Reason::HTTP_1_1_REQUIRED)
{
let mut err = Error::because(H2Downgrade, "while reading h2 header", e);
err.retry = true.into();
err
} else if e.is_go_away()
&& e.is_library()
&& e.reason()
.map_or(false, |r| r == h2::Reason::PROTOCOL_ERROR)
{
let mut err = Error::because(InvalidH2, "while reading h2 header", e);
err.retry = true.into();
err
} else if e.is_go_away()
&& e.is_remote()
&& e.reason().map_or(false, |r| r == h2::Reason::NO_ERROR)
{
let mut err = Error::because(H2Error, "while reading h2 header", e);
err.retry = true.into();
err
} else if e.is_io() {
let true_io_error = e.get_io().unwrap().raw_os_error().is_some();
let mut err = Error::because(ReadError, "while reading h2 header", e);
if true_io_error {
err.retry = RetryType::ReusedOnly;
} err
} else {
Error::because(H2Error, "while reading h2 header", e)
}
}
use tokio::sync::oneshot;
pub async fn drive_connection<S>(
mut c: client::Connection<S>,
id: UniqueIDType,
closed: watch::Sender<bool>,
ping_interval: Option<Duration>,
ping_timeout_occurred: Arc<AtomicBool>,
) where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
let interval = ping_interval.unwrap_or(Duration::ZERO);
if !interval.is_zero() {
let (tx, rx) = oneshot::channel::<()>();
let dropped = Arc::new(AtomicBool::new(false));
let dropped2 = dropped.clone();
if let Some(ping_pong) = c.ping_pong() {
pingora_runtime::current_handle().spawn(async move {
do_ping_pong(ping_pong, interval, tx, dropped2, id).await;
});
} else {
warn!("Cannot get ping-pong handler from h2 connection");
}
tokio::select! {
r = c => match r {
Ok(_) => debug!("H2 connection finished fd: {id}"),
Err(e) => debug!("H2 connection fd: {id} errored: {e:?}"),
},
r = rx => match r {
Ok(_) => {
ping_timeout_occurred.store(true, Ordering::Relaxed);
warn!("H2 connection Ping timeout/Error fd: {id}, closing conn");
},
Err(e) => warn!("H2 connection Ping Rx error {e:?}"),
},
};
dropped.store(true, Ordering::Relaxed);
} else {
match c.await {
Ok(_) => debug!("H2 connection finished fd: {id}"),
Err(e) => debug!("H2 connection fd: {id} errored: {e:?}"),
}
}
let _ = closed.send(true);
}
/// How long a PING may go unanswered before it counts as failed.
const PING_TIMEOUT: Duration = Duration::from_secs(5);

/// Keep pinging the peer through `ping_pong` until told to stop or a ping fails.
///
/// The first PING goes out after `interval`, and each later one `interval` after the previous
/// PONG. A PING that is not answered within [`PING_TIMEOUT`], or that errors while `dropped`
/// is still unset, is reported by sending `()` on `tx`, which ends the loop. The loop also
/// ends, without reporting, once `dropped` is set.
async fn do_ping_pong(
    mut ping_pong: h2::PingPong,
    interval: Duration,
    tx: oneshot::Sender<()>,
    dropped: Arc<AtomicBool>,
    id: UniqueIDType,
) {
    let stopped = || dropped.load(Ordering::Relaxed);

    tokio::time::sleep(interval).await;
    while !stopped() {
        let ping_fut = ping_pong.ping(h2::Ping::opaque());
        debug!("H2 fd: {id} ping sent");
        match tokio::time::timeout(PING_TIMEOUT, ping_fut).await {
            Ok(Ok(_)) => {
                debug!("H2 fd: {} pong received", id);
                tokio::time::sleep(interval).await;
            }
            // the connection is no longer driven, so the failure is expected: stop quietly
            Ok(Err(_)) if stopped() => break,
            Ok(Err(e)) => {
                error!("H2 fd: {id} ping error: {e}");
                let _ = tx.send(());
                break;
            }
            Err(_) => {
                error!("H2 fd: {id} ping timeout");
                let _ = tx.send(());
                break;
            }
        }
    }
}