use std::net::SocketAddr;
use futures::future::BoxFuture;
use http::{Request, Response, Uri};
use hyper::body::{Body, Bytes};
use crate::{client_body::H3IncomingClient, connection::CacheSendRequestService};
pub trait H3Connector: Send + 'static + Clone {
type CONN: h3::quic::Connection<
Bytes,
OpenStreams = Self::OS,
SendStream = Self::SS,
RecvStream = Self::RS,
OpenError = Self::OE,
> + Send;
type OS: h3::quic::OpenStreams<Bytes, OpenError = Self::OE, BidiStream = Self::BS>
+ Clone
+ Send; type SS: h3::quic::SendStream<Bytes> + Send;
type RS: h3::quic::RecvStream + Send;
type OE: Into<Box<dyn std::error::Error>> + Send;
type BS: h3::quic::BidiStream<Bytes, RecvStream = Self::RS, SendStream = Self::SS> + Send;
fn connect(
&self,
) -> impl std::future::Future<Output = Result<Self::CONN, crate::Error>> + std::marker::Send;
}
pub async fn dns_resolve(uri: &Uri) -> std::io::Result<Vec<SocketAddr>> {
let host_port = uri
.authority()
.ok_or(std::io::Error::from(std::io::ErrorKind::InvalidInput))?
.as_str();
match host_port.parse::<SocketAddr>() {
Ok(addr) => Ok(vec![addr]),
Err(_) => {
tokio::net::lookup_host(host_port)
.await
.map(|a| a.collect::<Vec<_>>())
}
}
}
pub struct H3Connection<C, B>
where
C: H3Connector,
B: Body + Send + 'static + Unpin,
B::Data: Send,
B::Error: Into<crate::Error>,
{
#[allow(clippy::type_complexity)]
inner:
tower::util::BoxService<Request<B>, Response<H3IncomingClient<C::RS, Bytes>>, crate::Error>,
}
impl<C, B> H3Connection<C, B>
where
C: H3Connector,
B: Body + Send + 'static + Unpin,
B::Data: Send,
B::Error: Into<crate::Error> + Send,
{
pub fn new(connector: C, uri: Uri) -> Self {
let cache_mk_svc = CacheSendRequestService::new(connector);
let client_svc = ClientService {
inner: cache_mk_svc,
uri,
};
Self {
inner: tower::util::BoxService::new(client_svc),
}
}
}
impl<C, B> tower::Service<Request<B>> for H3Connection<C, B>
where
C: H3Connector,
B: Body + Send + 'static + Unpin,
B::Data: Send,
B::Error: Into<crate::Error>,
{
type Response = Response<H3IncomingClient<C::RS, Bytes>>;
type Error = crate::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
tower::Service::poll_ready(&mut self.inner, cx)
}
fn call(&mut self, req: Request<B>) -> Self::Future {
self.inner.call(req)
}
}
pub struct ClientService<C>
where
C: H3Connector,
{
inner: CacheSendRequestService<C>,
uri: Uri,
}
impl<C, B> tower::Service<Request<B>> for ClientService<C>
where
C: H3Connector,
B: Body + Send + 'static + Unpin,
B::Data: Send,
B::Error: Into<crate::Error> + Send,
{
type Response = Response<H3IncomingClient<C::RS, Bytes>>;
type Error = crate::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
let uri = &self.uri;
let uri2 = Uri::builder()
.scheme(uri.scheme().unwrap().clone())
.authority(uri.authority().unwrap().clone())
.path_and_query(req.uri().path_and_query().unwrap().clone())
.build()
.unwrap();
*req.uri_mut() = uri2;
let fut = self.inner.call(());
Box::pin(async move {
let mut send_request = fut.await?;
send_request.call(req).await
})
}
}
pub struct H3Client<C, B>
where
C: H3Connector,
B: Body + Send + 'static + Unpin,
B::Data: Send,
B::Error: Into<crate::Error> + Send,
{
channel: H3Connection<C, B>,
}
impl<C, B> H3Client<C, B>
where
C: H3Connector,
B: Body + Send + 'static + Unpin,
B::Data: Send,
B::Error: Into<crate::Error> + Send,
{
pub fn new(inner: H3Connection<C, B>) -> Self {
Self { channel: inner }
}
pub async fn send(
&mut self,
req: http::Request<B>,
) -> Result<Response<H3IncomingClient<C::RS, Bytes>>, crate::Error> {
use tower::Service;
futures::future::poll_fn(|cx| self.channel.poll_ready(cx)).await?;
self.channel.call(req).await
}
}