httproxide_client_util/
lib.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::sync::Mutex;
5use std::task::{Context, Poll};
6use std::time::Duration;
7
8use futures_util::future::BoxFuture;
9use hyper::body::HttpBody;
10use hyper::client::connect::Connection;
11use hyper::client::HttpConnector;
12use hyper::{Client as HyperClient, Uri};
13use serde::{Deserialize, Serialize};
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15use tower::{Service, ServiceExt};
16
17pub type Client<B> = HyperClient<Connector, B>;
18
19#[derive(Clone, Debug)]
20pub struct Connector {
21    #[cfg(feature = "https")]
22    http: hyper_rustls::HttpsConnector<HttpConnector>,
23    #[cfg(not(feature = "https"))]
24    http: HttpConnector,
25    #[cfg(feature = "unix")]
26    unix: hyperlocal::UnixConnector,
27}
28
29pub enum Stream {
30    #[cfg(feature = "https")]
31    Http(hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream>),
32    #[cfg(not(feature = "https"))]
33    Http(tokio::net::TcpStream),
34    #[cfg(feature = "unix")]
35    Unix(hyperlocal::UnixStream),
36}
37
38impl Service<Uri> for Connector {
39    type Response = Stream;
40    type Error = tower::BoxError;
41    type Future = BoxFuture<'static, Result<Stream, Self::Error>>;
42
43    fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
44        Poll::Ready(Ok(()))
45    }
46
47    fn call(&mut self, dst: Uri) -> Self::Future {
48        #[cfg(feature = "unix")]
49        if dst.scheme_str() == Some("unix") {
50            let clone = self.unix.clone();
51            let unix = std::mem::replace(&mut self.unix, clone);
52            return Box::pin(async move { Ok(Stream::Unix(unix.oneshot(dst).await?)) });
53        }
54        let clone = self.http.clone();
55        let http = std::mem::replace(&mut self.http, clone);
56        Box::pin(async move { Ok(Stream::Http(http.oneshot(dst).await?)) })
57    }
58}
59
60impl AsyncRead for Stream {
61    fn poll_read(
62        self: Pin<&mut Self>,
63        cx: &mut Context,
64        buf: &mut ReadBuf,
65    ) -> Poll<std::io::Result<()>> {
66        match self.get_mut() {
67            Stream::Http(s) => Pin::new(s).poll_read(cx, buf),
68            #[cfg(feature = "unix")]
69            Stream::Unix(s) => Pin::new(s).poll_read(cx, buf),
70        }
71    }
72}
73
74impl AsyncWrite for Stream {
75    fn poll_write(
76        self: Pin<&mut Self>,
77        cx: &mut Context,
78        buf: &[u8],
79    ) -> Poll<std::io::Result<usize>> {
80        match self.get_mut() {
81            Stream::Http(ref mut s) => Pin::new(s).poll_write(cx, buf),
82            #[cfg(feature = "unix")]
83            Stream::Unix(ref mut s) => Pin::new(s).poll_write(cx, buf),
84        }
85    }
86    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
87        match self.get_mut() {
88            Stream::Http(ref mut s) => Pin::new(s).poll_flush(cx),
89            #[cfg(feature = "unix")]
90            Stream::Unix(ref mut s) => Pin::new(s).poll_flush(cx),
91        }
92    }
93    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
94        match self.get_mut() {
95            Stream::Http(ref mut s) => Pin::new(s).poll_flush(cx),
96            #[cfg(feature = "unix")]
97            Stream::Unix(ref mut s) => Pin::new(s).poll_flush(cx),
98        }
99    }
100}
101
102impl Connection for Stream {
103    fn connected(&self) -> hyper::client::connect::Connected {
104        match self {
105            Stream::Http(s) => s.connected(),
106            #[cfg(feature = "unix")]
107            Stream::Unix(s) => s.connected(),
108        }
109    }
110}
111
112#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
113pub struct ClientConfig {
114    dangerous_skip_cert_check: bool,
115}
116
117lazy_static::lazy_static! {
118    static ref CACHE: Mutex<HashMap<ClientConfig, Box<dyn std::any::Any + Send>>> = {
119        Mutex::new(HashMap::new())
120    };
121}
122
123pub fn clear_cache() {
124    *(*CACHE).lock().unwrap() = HashMap::new();
125}
126
/// Certificate verifier that accepts any server certificate.
///
/// Using it disables server authentication for the TLS connection entirely.
#[cfg(feature = "https")]
struct NoCertVerifier {}

#[cfg(feature = "https")]
impl rustls::client::ServerCertVerifier for NoCertVerifier {
    /// Reports the certificate as verified without inspecting any argument.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::Certificate,
        _intermediates: &[rustls::Certificate],
        _server_name: &rustls::client::ServerName,
        _scts: &mut dyn Iterator<Item = &[u8]>,
        _ocsp_response: &[u8],
        _now: std::time::SystemTime,
    ) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
        use rustls::client::ServerCertVerified;

        Ok(ServerCertVerified::assertion())
    }
}
144
145fn new_client<B>(cfg: ClientConfig) -> anyhow::Result<Client<B>>
146where
147    B: HttpBody + Send + 'static,
148    B::Data: Send,
149{
150    let mut http = HttpConnector::new();
151
152    #[cfg(feature = "https")]
153    let https = {
154        use hyper_rustls::ConfigBuilderExt;
155
156        http.enforce_http(false);
157
158        let tls_config = {
159            rustls::ClientConfig::builder()
160                .with_safe_defaults()
161        };
162
163        let tls_config = if cfg.dangerous_skip_cert_check {
164            tls_config
165                .with_custom_certificate_verifier(Arc::new(NoCertVerifier {}))
166                .with_no_client_auth()
167        } else {
168            tls_config
169                .with_native_roots()
170                .with_no_client_auth()
171        };
172
173        let tls = hyper_rustls::HttpsConnectorBuilder::new()
174            .with_tls_config(tls_config)
175            .https_or_http()
176            .enable_http1();
177
178        #[cfg(feature = "http2")]
179        {
180            tls.enable_http2().wrap_connector(http)
181        }
182
183        #[cfg(not(feature = "http2"))]
184        {
185            tls.wrap_connector(http)
186        }
187    };
188
189    let connector = Connector {
190        #[cfg(feature = "https")]
191        http: https,
192        #[cfg(not(feature = "https"))]
193        http,
194        #[cfg(feature = "unix")]
195        unix: hyperlocal::UnixConnector,
196    };
197
198    let client = HyperClient::builder()
199        .pool_idle_timeout(Duration::from_secs(30))
200        .build(connector);
201
202    Ok(client)
203}
204
205pub fn get_client<B>(cfg: ClientConfig) -> anyhow::Result<Client<B>>
206where
207    B: HttpBody + Send + 'static,
208    B::Data: Send,
209{
210    let mut cache = (*CACHE).lock().unwrap();
211    if let Some(ref val) = cache.get(&cfg).and_then(|x| x.downcast_ref::<Client<B>>()) {
212        Ok((*val).clone())
213    } else {
214        let new_val = new_client(cfg.clone())?;
215        cache.insert(cfg, Box::new(new_val.clone()));
216        Ok(new_val)
217    }
218}