use std::{pin::Pin, time::Duration};
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::{future::poll_fn, stream::Stream};
use http_body_util::Full;
use hyper::body::{Body, Incoming};
use hyper_tls::HttpsConnector;
use hyper_util::client::legacy::Client;
use reqwest::{Request, Response};
use reqwest_middleware::{Error, Middleware, Next};
use crate::{SsConnector, error::SsConnectorError};
/// Legacy hyper client used by [`SsMiddleware`]: connections are opened by [`SsConnector`]
/// wrapped in a hyper-tls [`HttpsConnector`], and request bodies are sent as a single
/// in-memory [`Full<Bytes>`] buffer.
type HyperClient = Client<HttpsConnector<SsConnector>, Full<Bytes>>;
/// Adapts a hyper [`Incoming`] response body into a fallible stream of byte chunks.
///
/// Every data frame becomes one `Bytes` item; any non-data frame (e.g. trailers) is skipped.
/// If polling the body fails, the error is wrapped in [`anyhow::Error`] and ends the stream.
fn body_to_stream(mut body: Incoming) -> impl Stream<Item = Result<Bytes, anyhow::Error>> {
    async_stream::try_stream! {
        loop {
            // `None` means the body has no more frames.
            let frame = match poll_fn(|cx| Pin::new(&mut body).poll_frame(cx)).await {
                Some(polled) => polled.map_err(|err| anyhow::anyhow!(err))?,
                None => break,
            };
            // Take the data by value; frames that are not data are dropped.
            if let Ok(data) = frame.into_data() {
                yield data;
            }
        }
    }
}
/// `reqwest_middleware` middleware that performs each request itself through a hyper client
/// backed by [`SsConnector`], rather than handing it to the next middleware in the chain.
pub struct SsMiddleware {
    // Hyper client (with connection pooling configured in `new`) used for every request.
    client: HyperClient,
}
impl SsMiddleware {
    /// Creates a middleware whose HTTPS client opens connections through `connector`.
    ///
    /// The client runs on the Tokio executor, and pooled connections that stay idle for
    /// 90 seconds are closed.
    ///
    /// # Panics
    ///
    /// NOTE(review): hyper-tls documents that `HttpsConnector::new_with_connector` panics if the
    /// native TLS backend cannot be initialised — confirm this is acceptable for callers.
    pub fn new(connector: SsConnector) -> Self {
        let mut builder = Client::builder(hyper_util::rt::TokioExecutor::new());
        builder.pool_idle_timeout(Duration::from_secs(90));
        Self {
            client: builder.build(HttpsConnector::new_with_connector(connector)),
        }
    }

    /// Parses `url` into an [`SsConnector`] and builds a middleware on top of it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SsConnector::new` for an unusable `url`.
    pub fn from_url(url: &str) -> Result<Self, SsConnectorError> {
        Ok(Self::new(SsConnector::new(url)?))
    }
}
#[async_trait]
impl Middleware for SsMiddleware {
async fn handle(
&self,
req: Request,
_extensions: &mut http::Extensions,
_next: Next<'_>,
) -> Result<Response, Error> {
let (mut parts, _) = hyper::Request::new(Full::<Bytes>::default()).into_parts();
parts.method = req.method().clone();
parts.uri = req.url().as_str().parse().unwrap();
parts.headers = req.headers().clone();
parts.version = req.version();
let body = match req.body() {
Some(body) => body.as_bytes().unwrap_or_default(),
None => &[],
};
let hyper_req = hyper::Request::from_parts(parts, Full::new(Bytes::from(body.to_vec())));
let http_res = self
.client
.request(hyper_req)
.await
.map_err(|e| Error::Middleware(anyhow::anyhow!(SsConnectorError::HyperClient(e))))?;
let (parts, body) = http_res.into_parts();
let stream = body_to_stream(body);
let res = hyper::Response::from_parts(parts, reqwest::Body::wrap_stream(stream));
Ok(Response::from(res))
}
}