use bytes::Bytes;
use http::{Method, Uri};
use tokio::sync::mpsc;
use crate::error::{Error, Result};
use crate::headers::Headers;
use crate::response::Response;
use crate::transport::h2::driver::DriverCommand;
use crate::transport::h2::tunnel::H2Tunnel;
#[derive(Clone)]
pub struct H2Handle {
command_tx: mpsc::Sender<DriverCommand>,
goaway_received: std::sync::Arc<std::sync::atomic::AtomicBool>,
}
impl H2Handle {
pub fn new(
command_tx: mpsc::Sender<DriverCommand>,
goaway_received: std::sync::Arc<std::sync::atomic::AtomicBool>,
) -> Self {
Self {
command_tx,
goaway_received,
}
}
pub fn is_alive(&self) -> bool {
!self.command_tx.is_closed()
&& !self
.goaway_received
.load(std::sync::atomic::Ordering::Relaxed)
}
pub async fn send_request(
&self,
method: Method,
uri: &Uri,
headers: Vec<(String, String)>,
body: Option<Bytes>,
) -> Result<Response> {
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
let command = DriverCommand::SendRequest {
method,
uri: uri.clone(),
headers,
body,
response_tx,
};
self.command_tx
.send(command)
.await
.map_err(|_| Error::HttpProtocol("Driver channel closed".into()))?;
let stream_response = response_rx
.await
.map_err(|_| Error::HttpProtocol("Response channel closed".into()))??;
Ok(Response::new(
stream_response.status,
Headers::from(stream_response.headers),
stream_response.body,
"HTTP/2".to_string(),
))
}
pub async fn send_streaming_request(
&self,
method: Method,
uri: &Uri,
headers: Vec<(String, String)>,
body: Option<Bytes>,
) -> Result<(Response, mpsc::Receiver<Result<Bytes>>)> {
let (headers_tx, headers_rx) = tokio::sync::oneshot::channel();
let (body_tx, body_rx) = mpsc::channel(32);
let (internal_tx, mut internal_rx) = mpsc::unbounded_channel::<Result<Bytes>>();
let body_tx_clone = body_tx.clone();
tokio::spawn(async move {
while let Some(item) = internal_rx.recv().await {
if body_tx_clone.send(item).await.is_err() {
break;
}
}
});
let command = DriverCommand::SendStreamingRequest {
method,
uri: uri.clone(),
headers,
body,
body_tx: internal_tx,
headers_tx,
};
self.command_tx
.send(command)
.await
.map_err(|_| Error::HttpProtocol("Driver channel closed".into()))?;
let (status, regular_headers) = headers_rx
.await
.map_err(|_| Error::HttpProtocol("Headers channel closed".into()))??;
Ok((
Response::new(
status,
Headers::from(regular_headers),
Bytes::new(),
"HTTP/2".to_string(),
),
body_rx,
))
}
pub async fn open_websocket_tunnel(
&self,
uri: Uri,
headers: Vec<(String, String)>,
) -> Result<H2Tunnel> {
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
self.command_tx
.send(DriverCommand::OpenWebSocketTunnel {
uri,
headers,
response_tx,
})
.await
.map_err(|_| Error::HttpProtocol("Driver channel closed".into()))?;
response_rx
.await
.map_err(|_| Error::HttpProtocol("Tunnel response channel closed".into()))?
}
}