use bytes::Bytes;
use http::{Method, Uri};
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::io::WriteHalf;
use tokio::sync::{mpsc, oneshot};
use crate::error::{Error, Result};
use crate::headers::Headers;
use crate::request::RequestBody;
use crate::response::{Body, Response};
use crate::transport::connector::MaybeHttpsStream;
use crate::transport::h2::body::{H2Body, H2BodyShared, H2BodyTimeouts};
use crate::transport::h2::driver::{DriverCommand, InlineRegistration, StreamingHeadersResult};
use crate::transport::h2::tunnel::H2Tunnel;
use crate::transport::h2::write_half::H2WriteHalf;
use crate::transport::h2::H2TransportConfig;
/// State shared with the connection driver that lets [`H2Handle`] start empty-body
/// streaming requests itself (the "inline" path) instead of going through the driver's
/// command channel.
pub(crate) struct H2InlineState {
    /// Write half of the connection; the inline path writes request HEADERS through it
    /// directly.
    pub(crate) write_half: Arc<H2WriteHalf<WriteHalf<MaybeHttpsStream>>>,
    /// Peer's maximum frame size, loaded right before request headers are written.
    pub(crate) peer_max_frame_size: Arc<AtomicU32>,
    /// Initial flow-control window used to size new response-body buffers and as the
    /// receive window handed to the driver in each `InlineRegistration`.
    pub(crate) initial_window_size: u32,
    /// Sends an `InlineRegistration` (stream id plus the response/trailer channels) to the
    /// driver after the handle has opened a stream itself; presumably the driver uses it
    /// to deliver that stream's response — TODO confirm against the driver.
    pub(crate) register_tx: mpsc::UnboundedSender<InlineRegistration>,
    /// Count of in-flight inline requests. The handle only takes the inline path when it
    /// can move this from 0 to 1, and gives the slot back itself only on early failure;
    /// presumably the driver releases it once a registered stream finishes — verify.
    pub(crate) inline_active: Arc<AtomicUsize>,
    /// Gate checked before attempting the inline path; it is toggled outside this file
    /// (not visible here).
    pub(crate) inline_eligible: Arc<AtomicBool>,
    /// Notifier shared with every `H2BodyShared` the handle creates.
    pub(crate) body_progress_notify: Arc<tokio::sync::Notify>,
    /// Buffer capacity (in slots) for response bodies created on the inline path.
    pub(crate) streaming_body_buffer_slots: usize,
}
/// Cloneable handle for submitting requests to an HTTP/2 connection driver.
///
/// Requests normally travel as `DriverCommand`s over `command_tx`. When the handle was
/// built with inline state, empty-body streaming requests may instead be written directly
/// to the connection and registered with the driver afterwards.
#[derive(Clone)]
pub struct H2Handle {
    /// Channel to the connection driver; a closed channel makes `is_alive` return false.
    command_tx: mpsc::Sender<DriverCommand>,
    /// Flag read by `is_alive`; presumably set when the peer sends GOAWAY — TODO confirm
    /// where it is written.
    goaway_received: Arc<AtomicBool>,
    /// Present only for handles created with `with_inline`; enables the inline fast path.
    inline: Option<Arc<H2InlineState>>,
    /// Transport settings, stored in normalized form by every constructor.
    transport_config: H2TransportConfig,
    /// Shared counter exposed by `backpressure_stall_count()`; this handle only reads it.
    backpressure_stall_count: Arc<AtomicU64>,
}
impl H2Handle {
    /// Default HTTP/2 initial flow-control window (65,535 octets), used when this handle
    /// has no inline state to supply one.
    const DEFAULT_INITIAL_WINDOW_SIZE: u32 = 65_535;

    /// Creates a handle that talks to the driver only through `command_tx`, using the
    /// default transport configuration and a fresh backpressure-stall counter.
    pub fn new(command_tx: mpsc::Sender<DriverCommand>, goaway_received: Arc<AtomicBool>) -> Self {
        Self::new_with_config(
            command_tx,
            goaway_received,
            H2TransportConfig::default(),
            Arc::new(AtomicU64::new(0)),
        )
    }

    /// Creates a handle without inline state, normalizing `transport_config` and sharing
    /// the given stall counter.
    pub(crate) fn new_with_config(
        command_tx: mpsc::Sender<DriverCommand>,
        goaway_received: Arc<AtomicBool>,
        transport_config: H2TransportConfig,
        backpressure_stall_count: Arc<AtomicU64>,
    ) -> Self {
        Self {
            command_tx,
            goaway_received,
            inline: None,
            transport_config: transport_config.normalized(),
            backpressure_stall_count,
        }
    }

    /// Creates a handle that may use the inline fast path described by `inline`.
    pub(crate) fn with_inline(
        command_tx: mpsc::Sender<DriverCommand>,
        goaway_received: Arc<AtomicBool>,
        inline: Arc<H2InlineState>,
        transport_config: H2TransportConfig,
        backpressure_stall_count: Arc<AtomicU64>,
    ) -> Self {
        Self {
            command_tx,
            goaway_received,
            inline: Some(inline),
            transport_config: transport_config.normalized(),
            backpressure_stall_count,
        }
    }

    /// Returns `true` while the driver channel is open and the GOAWAY flag is unset.
    pub fn is_alive(&self) -> bool {
        !self.command_tx.is_closed() && !self.goaway_received.load(Ordering::Relaxed)
    }

    /// Body-buffer capacity (in slots) from the normalized transport configuration.
    pub fn streaming_body_buffer_slots(&self) -> usize {
        self.transport_config.streaming_body_buffer_slots
    }

    /// Current value of the shared backpressure-stall counter.
    pub fn backpressure_stall_count(&self) -> u64 {
        self.backpressure_stall_count.load(Ordering::Relaxed)
    }

    /// Sends a request with an optional, fully buffered body through the driver and waits
    /// for its response.
    ///
    /// # Errors
    /// `Error::HttpProtocol` if the driver or response channel closes; any error the
    /// driver reports for the request is propagated unchanged.
    pub async fn send_request(
        &self,
        method: Method,
        uri: &Uri,
        headers: impl Into<Headers>,
        body: Option<Bytes>,
    ) -> Result<Response> {
        let (response_tx, response_rx) = oneshot::channel();
        let headers = headers.into();
        let command = DriverCommand::SendRequest {
            method,
            uri: uri.clone(),
            headers,
            body,
            response_tx,
        };
        self.command_tx
            .send(command)
            .await
            .map_err(|_| Error::HttpProtocol("Driver channel closed".into()))?;
        let stream_response = response_rx
            .await
            .map_err(|_| Error::HttpProtocol("Response channel closed".into()))??;
        Ok(Response::new(
            stream_response.status,
            Headers::from(stream_response.headers),
            stream_response.body,
            "HTTP/2".to_string(),
        ))
    }

    /// Sends a request and returns as soon as response headers arrive, with the body
    /// streamed afterwards.
    ///
    /// Empty-body requests first try the inline fast path. If that path is unavailable
    /// (no inline state, handle not alive, not eligible, or another inline request in
    /// flight), the request goes through the driver command channel. Errors raised after
    /// the inline path was taken are returned as-is, not retried.
    pub async fn send_streaming_request(
        &self,
        method: Method,
        uri: &Uri,
        headers: impl Into<Headers>,
        body: RequestBody,
        body_timeouts: H2BodyTimeouts,
    ) -> Result<Response> {
        let headers = headers.into();
        let body_is_empty = body.is_empty();
        if let Some(result) = self
            .try_send_streaming_inline(&method, uri, &headers, body_is_empty, body_timeouts)
            .await
        {
            return result;
        }
        self.send_streaming_request_command_path(method, uri, &headers, body, body_timeouts)
            .await
    }

    /// Creates the trailers channel pair only when the request advertises `te: trailers`.
    fn trailers_channel<T>(
        headers: &Headers,
    ) -> (Option<oneshot::Sender<T>>, Option<oneshot::Receiver<T>>) {
        if wants_trailers(headers) {
            let (tx, rx) = oneshot::channel();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        }
    }

    /// Streaming request via `DriverCommand::SendStreamingRequest`; waits for the
    /// response headers and wraps the shared body buffer in a streaming `Body`.
    async fn send_streaming_request_command_path(
        &self,
        method: Method,
        uri: &Uri,
        headers: &Headers,
        body: RequestBody,
        body_timeouts: H2BodyTimeouts,
    ) -> Result<Response> {
        let (headers_tx, headers_rx) = oneshot::channel();
        let (trailers_tx, trailers_rx) = Self::trailers_channel(headers);
        let initial_window_size = self
            .inline
            .as_ref()
            .map_or(Self::DEFAULT_INITIAL_WINDOW_SIZE, |inline| {
                inline.initial_window_size
            });
        let body_shared = H2BodyShared::new_with_capacity(
            self.body_progress_notify(),
            initial_window_size,
            self.transport_config.streaming_body_buffer_slots,
        );
        let command = DriverCommand::SendStreamingRequest {
            method,
            uri: uri.clone(),
            headers: headers.clone(),
            body,
            body_shared: body_shared.clone(),
            headers_tx,
            trailers_tx,
        };
        self.command_tx
            .send(command)
            .await
            .map_err(|_| Error::HttpProtocol("Driver channel closed".into()))?;
        let (status, regular_headers) = headers_rx
            .await
            .map_err(|_| Error::HttpProtocol("Headers channel closed".into()))??;
        Ok(Response::with_body(
            status,
            Headers::from(regular_headers),
            Body::from_h2(H2Body::new_with_trailers(
                body_shared,
                body_timeouts,
                trailers_rx,
            )),
            "HTTP/2".to_string(),
        ))
    }

    /// Inline fast path: writes the request HEADERS directly, then registers the stream
    /// with the driver to receive the response.
    ///
    /// Returns `None` when the inline path cannot be used (the caller then falls back to
    /// the command path), otherwise `Some` with the outcome.
    async fn try_send_streaming_inline(
        &self,
        method: &Method,
        uri: &Uri,
        headers: &Headers,
        body_is_empty: bool,
        body_timeouts: H2BodyTimeouts,
    ) -> Option<Result<Response>> {
        /// Holds the inline slot claimed from `inline_active`. Dropping the guard gives the
        /// slot back, including when this future is cancelled mid-await, unless
        /// `hand_off` was called after the driver accepted the registration.
        struct InlineSlot<'a> {
            active: &'a AtomicUsize,
            armed: bool,
        }

        impl InlineSlot<'_> {
            /// Transfers responsibility for the slot to the driver.
            fn hand_off(mut self) {
                self.armed = false;
            }
        }

        impl Drop for InlineSlot<'_> {
            fn drop(&mut self) {
                if self.armed {
                    self.active.fetch_sub(1, Ordering::AcqRel);
                }
            }
        }

        let inline = self.inline.as_ref()?;
        if !self.is_alive() || !body_is_empty || !inline.inline_eligible.load(Ordering::Relaxed)
        {
            return None;
        }
        // Only one inline request may be in flight; claim the slot atomically.
        if inline
            .inline_active
            .compare_exchange(0, 1, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            return None;
        }
        let slot = InlineSlot {
            active: &inline.inline_active,
            armed: true,
        };

        let (headers_tx, headers_rx) = oneshot::channel::<StreamingHeadersResult>();
        let (trailers_tx, trailers_rx) = Self::trailers_channel(headers);
        let body_shared = H2BodyShared::new_with_capacity(
            inline.body_progress_notify.clone(),
            inline.initial_window_size,
            inline.streaming_body_buffer_slots,
        );
        let max_frame_size = inline.peer_max_frame_size.load(Ordering::Relaxed) as usize;
        // NOTE(review): the `true` argument presumably sets END_STREAM because the body is
        // empty — confirm against `H2WriteHalf::write_request_headers`.
        let stream_id = match inline
            .write_half
            .write_request_headers(method, uri, headers, true, max_frame_size)
            .await
        {
            Ok(id) => id,
            // Returning drops `slot`, which releases the inline slot.
            Err(error) => return Some(Err(error)),
        };
        let registration = InlineRegistration {
            stream_id,
            headers_tx,
            body_shared: body_shared.clone(),
            // HTTP/2 windows cannot exceed 2^31 - 1; clamp instead of letting `as` wrap a
            // large u32 into a negative window. Both casts are lossless after the `min`.
            recv_window: inline.initial_window_size.min(i32::MAX as u32) as i32,
            trailers_tx,
        };
        if inline.register_tx.send(registration).is_err() {
            return Some(Err(Error::HttpProtocol("Driver channel closed".into())));
        }
        // The driver now owns the registered stream and, with it, the inline slot.
        slot.hand_off();

        let result = match headers_rx.await {
            Ok(Ok((status, regular_headers))) => Ok(Response::with_body(
                status,
                Headers::from(regular_headers),
                Body::from_h2(H2Body::new_with_trailers(
                    body_shared,
                    body_timeouts,
                    trailers_rx,
                )),
                "HTTP/2".to_string(),
            )),
            Ok(Err(e)) => Err(e),
            Err(_) => Err(Error::HttpProtocol("Headers channel closed".into())),
        };
        Some(result)
    }

    /// The inline state's shared notifier, or a fresh private one when there is none.
    fn body_progress_notify(&self) -> Arc<tokio::sync::Notify> {
        self.inline
            .as_ref()
            .map(|inline| inline.body_progress_notify.clone())
            .unwrap_or_else(|| Arc::new(tokio::sync::Notify::new()))
    }

    /// Asks the driver to open a WebSocket tunnel over this connection.
    ///
    /// # Errors
    /// `Error::HttpProtocol` if the driver or tunnel-response channel closes; otherwise
    /// the driver's own result is returned.
    pub async fn open_websocket_tunnel(
        &self,
        uri: Uri,
        headers: impl Into<Headers>,
    ) -> Result<H2Tunnel> {
        let (response_tx, response_rx) = oneshot::channel();
        let headers = headers.into();
        self.command_tx
            .send(DriverCommand::OpenWebSocketTunnel {
                uri,
                headers: headers.to_vec(),
                response_tx,
            })
            .await
            .map_err(|_| Error::HttpProtocol("Driver channel closed".into()))?;
        response_rx
            .await
            .map_err(|_| Error::HttpProtocol("Tunnel response channel closed".into()))?
    }
}
/// Reports whether any `te` header line lists the `trailers` token.
///
/// Values are split on commas, each token is whitespace-trimmed, and the comparison is
/// ASCII case-insensitive, so `TE: deflate, Trailers` matches.
fn wants_trailers(headers: &Headers) -> bool {
    for value in headers.get_all("te").iter() {
        let mut tokens = value.split(',');
        if tokens.any(|token| token.trim().eq_ignore_ascii_case("trailers")) {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::wants_trailers;
    use crate::headers::Headers;
    use crate::request::RequestBody;
    use crate::transport::h2::body::H2BodyShared;
    use crate::transport::h2::driver::{DriverCommand, InlineRegistration};
    use std::sync::Arc;
    use tokio::sync::{oneshot, Notify};

    /// Headers typical of a gRPC request that does not ask for trailers.
    fn headers_without_te() -> Headers {
        Headers::from_vec(vec![(
            "content-type".to_string(),
            "application/grpc+proto".to_string(),
        )])
    }

    /// Headers carrying a single `te: trailers` line.
    fn headers_with_te_trailers() -> Headers {
        Headers::from_vec(vec![("te".to_string(), "trailers".to_string())])
    }

    /// Mirrors the production rule: a trailers channel exists only for `te: trailers`.
    fn trailers_pair<T>(
        headers: &Headers,
    ) -> (Option<oneshot::Sender<T>>, Option<oneshot::Receiver<T>>) {
        if wants_trailers(headers) {
            let (tx, rx) = oneshot::channel();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        }
    }

    fn make_body_shared() -> Arc<H2BodyShared> {
        H2BodyShared::new_with_capacity(Arc::new(Notify::new()), 65_535, 16)
    }

    /// Builds a `SendStreamingRequest` for `headers` and reports whether it carries a
    /// trailers sender. Panics (instead of silently passing) if the variant is wrong.
    fn streaming_command_has_trailers_tx(headers: Headers) -> bool {
        let (headers_tx, _headers_rx) = oneshot::channel();
        let (trailers_tx, _trailers_rx) = trailers_pair(&headers);
        let command = DriverCommand::SendStreamingRequest {
            method: http::Method::POST,
            uri: "https://example.com/svc/method".parse().unwrap(),
            headers,
            body: RequestBody::Empty,
            body_shared: make_body_shared(),
            headers_tx,
            trailers_tx,
        };
        let DriverCommand::SendStreamingRequest { trailers_tx, .. } = command else {
            panic!("expected DriverCommand::SendStreamingRequest");
        };
        trailers_tx.is_some()
    }

    /// Builds an `InlineRegistration` for `headers` and reports whether it carries a
    /// trailers sender.
    fn inline_registration_has_trailers_tx(headers: &Headers) -> bool {
        let (headers_tx, _headers_rx) = oneshot::channel();
        let (trailers_tx, _trailers_rx) = trailers_pair(headers);
        let registration = InlineRegistration {
            stream_id: 1,
            headers_tx,
            body_shared: make_body_shared(),
            recv_window: 65_535,
            trailers_tx,
        };
        registration.trailers_tx.is_some()
    }

    #[test]
    fn wants_trailers_false_without_te() {
        assert!(!wants_trailers(&headers_without_te()));
    }

    #[test]
    fn wants_trailers_true_for_te_trailers() {
        assert!(wants_trailers(&headers_with_te_trailers()));
    }

    #[test]
    fn wants_trailers_true_in_te_list_and_case_insensitive() {
        let headers = Headers::from_vec(vec![("TE".to_string(), "deflate, Trailers".to_string())]);
        assert!(wants_trailers(&headers));
    }

    #[test]
    fn wants_trailers_false_for_unrelated_te() {
        let headers = Headers::from_vec(vec![("te".to_string(), "deflate, gzip".to_string())]);
        assert!(!wants_trailers(&headers));
    }

    #[test]
    fn wants_trailers_true_for_separate_te_lines() {
        let headers = Headers::from_vec(vec![
            ("te".to_string(), "deflate".to_string()),
            ("te".to_string(), "trailers".to_string()),
        ]);
        assert!(wants_trailers(&headers));
    }

    #[test]
    fn send_streaming_request_command_no_te_has_no_trailers_tx() {
        assert!(
            !streaming_command_has_trailers_tx(headers_without_te()),
            "no te:trailers -> trailers_tx must be None"
        );
    }

    #[test]
    fn send_streaming_request_command_with_te_has_trailers_tx() {
        assert!(
            streaming_command_has_trailers_tx(headers_with_te_trailers()),
            "te:trailers -> trailers_tx must be Some"
        );
    }

    #[test]
    fn inline_registration_no_te_has_no_trailers_tx() {
        assert!(
            !inline_registration_has_trailers_tx(&headers_without_te()),
            "no te:trailers -> InlineRegistration::trailers_tx must be None"
        );
    }

    #[test]
    fn inline_registration_with_te_has_trailers_tx() {
        assert!(
            inline_registration_has_trailers_tx(&headers_with_te_trailers()),
            "te:trailers -> InlineRegistration::trailers_tx must be Some"
        );
    }
}