use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use http_body_util::Full;
use hyper::Uri;
use hyper_tls::HttpsConnector;
use hyper_util::client::legacy::Client;
use reqwest::{Request, Response};
use reqwest_middleware::{Error, Middleware, Next};
use tower::Service;
use crate::{
Conn,
error::Error::HyperClient as HyperClientError,
util::{HyperClient, body_to_stream},
};
/// [`Middleware`] that sends each request through its own hyper client, built
/// over an [`HttpsConnector`] wrapping a caller-supplied connector `C`, rather
/// than handing it to the next middleware in the chain.
pub struct ProxyMiddleware<C> {
    // Pooled hyper client constructed in `ProxyMiddleware::new`; every request
    // handled by this middleware is sent through it.
    client: HyperClient<C>,
}
impl<C: Conn + 'static> ProxyMiddleware<C>
where
    <C as Service<Uri>>::Future: Send,
{
    /// Creates a middleware whose requests go out through `connector`, wrapped
    /// in an [`HttpsConnector`] for TLS and driven by a Tokio executor.
    ///
    /// Idle pooled connections are dropped after 90 seconds of inactivity.
    ///
    /// # Panics
    ///
    /// NOTE(review): `HttpsConnector::new_with_connector` does not return a
    /// `Result`; it presumably panics if the native TLS backend cannot be
    /// initialised — confirm against the hyper-tls version in use.
    pub fn new(connector: C) -> Self {
        let https = HttpsConnector::new_with_connector(connector);
        let client: HyperClient<C> = Client::builder(hyper_util::rt::TokioExecutor::new())
            // hyper-util documents that `pool_idle_timeout` only takes effect
            // when a pool timer is configured; without it, idle sockets are not
            // reaped in the background.
            .pool_timer(hyper_util::rt::TokioTimer::new())
            .pool_idle_timeout(Duration::from_secs(90))
            .build(https);
        Self { client }
    }
}
#[async_trait]
impl<C: Conn + 'static> Middleware for ProxyMiddleware<C>
where
<C as Service<Uri>>::Future: Send,
{
async fn handle(
&self,
req: Request,
_extensions: &mut http::Extensions,
_next: Next<'_>,
) -> Result<Response, Error> {
let mut hyper_req = hyper::Request::new(Full::<Bytes>::default());
*hyper_req.method_mut() = req.method().clone();
*hyper_req.uri_mut() = req
.url()
.as_str()
.parse()
.map_err(|e| Error::Middleware(anyhow::anyhow!("Invalid URI: {e}")))?;
*hyper_req.headers_mut() = req.headers().clone();
*hyper_req.version_mut() = req.version();
let body_bytes = if let Some(body) = req.body() {
if body.as_bytes().is_none() {
return Err(Error::Middleware(anyhow::anyhow!(
"ProxyMiddleware cannot handle streaming request bodies"
)));
}
body.as_bytes().unwrap_or_default().to_vec()
} else {
Vec::new()
};
let (parts, _) = hyper_req.into_parts();
let hyper_req = hyper::Request::from_parts(parts, Full::new(Bytes::from(body_bytes)));
let http_res = self
.client
.request(hyper_req)
.await
.map_err(|e| Error::Middleware(anyhow::anyhow!(HyperClientError(e))))?;
let (parts, body) = http_res.into_parts();
let stream = body_to_stream(body);
let res = hyper::Response::from_parts(parts, reqwest::Body::wrap_stream(stream));
Ok(Response::from(res))
}
}