Skip to main content

httproxide_client_util/
lib.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::{LazyLock, Mutex};
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use futures_util::future::BoxFuture;
8use hyper::body::Body;
9use hyper::rt::ReadBufCursor;
10use hyper::Uri;
11use hyper_util::client::legacy::connect::{Connection, HttpConnector};
12use hyper_util::client::legacy::Client as HyperClient;
13use hyper_util::rt::{TokioExecutor, TokioIo};
14use serde::{Deserialize, Serialize};
15use tower::{Service, ServiceExt};
16
/// The HTTP client type handed out by [`get_client`]: a `hyper_util` legacy
/// client that dials through this crate's [`Connector`] and sends request
/// bodies of type `B`.
pub type Client<B> = HyperClient<Connector, B>;
18
/// A `tower::Service<Uri>` connector that dials either TCP or, when the `unix`
/// feature is enabled and a socket is configured, a Unix socket.
///
/// With the `https` feature the TCP connector is wrapped by `hyper_rustls`
/// (configured with `https_or_http()`), so both `http://` and `https://` URIs
/// are accepted.
#[derive(Clone, Debug)]
pub struct Connector {
    /// TCP connector, TLS-capable when the `https` feature is enabled.
    #[cfg(feature = "https")]
    http: hyper_rustls::HttpsConnector<HttpConnector>,
    /// Plain TCP connector (built without the `https` feature).
    #[cfg(not(feature = "https"))]
    http: HttpConnector,
    /// Connector used for `unix`-scheme URIs. `None` when no socket path was
    /// configured, in which case such URIs fall through to `http`.
    #[cfg(feature = "unix")]
    unix: Option<hyper_unix_socket::UnixSocketConnector<String>>,
}
28
/// A connection produced by [`Connector`].
///
/// It implements hyper's `Read`/`Write` and `Connection` traits by delegating
/// to whichever variant it holds.
pub enum Stream {
    /// TCP connection, possibly TLS-wrapped (`https` feature).
    #[cfg(feature = "https")]
    Http(hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>),
    /// Plain TCP connection (built without the `https` feature).
    #[cfg(not(feature = "https"))]
    Http(TokioIo<tokio::net::TcpStream>),
    /// Unix-socket connection (`unix` feature).
    #[cfg(feature = "unix")]
    Unix(hyper_unix_socket::UnixSocketConnection),
}
37
38impl Service<Uri> for Connector {
39    type Response = Stream;
40    type Error = tower::BoxError;
41    type Future = BoxFuture<'static, Result<Stream, Self::Error>>;
42
43    fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
44        Poll::Ready(Ok(()))
45    }
46
47    fn call(&mut self, dst: Uri) -> Self::Future {
48        #[cfg(feature = "unix")]
49        if dst.scheme_str() == Some("unix") {
50            if let Some(unix_ref) = self.unix.as_mut() {
51                let clone = unix_ref.clone();
52                let unix = std::mem::replace(&mut *unix_ref, clone);
53                return Box::pin(async move { Ok(Stream::Unix(unix.oneshot(dst).await?)) });
54            }
55        }
56        let clone = self.http.clone();
57        let http = std::mem::replace(&mut self.http, clone);
58        Box::pin(async move { Ok(Stream::Http(http.oneshot(dst).await?)) })
59    }
60}
61
62impl hyper::rt::Read for Stream {
63    fn poll_read(
64        self: Pin<&mut Self>,
65        cx: &mut Context,
66        buf: ReadBufCursor,
67    ) -> Poll<std::io::Result<()>> {
68        match self.get_mut() {
69            Stream::Http(s) => Pin::new(s).poll_read(cx, buf),
70            #[cfg(feature = "unix")]
71            Stream::Unix(s) => Pin::new(s).poll_read(cx, buf),
72        }
73    }
74}
75
76impl hyper::rt::Write for Stream {
77    fn poll_write(
78        self: Pin<&mut Self>,
79        cx: &mut Context,
80        buf: &[u8],
81    ) -> Poll<std::io::Result<usize>> {
82        match self.get_mut() {
83            Stream::Http(s) => Pin::new(s).poll_write(cx, buf),
84            #[cfg(feature = "unix")]
85            Stream::Unix(s) => Pin::new(s).poll_write(cx, buf),
86        }
87    }
88    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
89        match self.get_mut() {
90            Stream::Http(s) => Pin::new(s).poll_flush(cx),
91            #[cfg(feature = "unix")]
92            Stream::Unix(s) => Pin::new(s).poll_flush(cx),
93        }
94    }
95    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
96        match self.get_mut() {
97            Stream::Http(s) => Pin::new(s).poll_flush(cx),
98            #[cfg(feature = "unix")]
99            Stream::Unix(s) => Pin::new(s).poll_flush(cx),
100        }
101    }
102}
103
104impl Connection for Stream {
105    fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
106        match self {
107            Stream::Http(s) => s.connected(),
108            #[cfg(feature = "unix")]
109            Stream::Unix(s) => s.connected(),
110        }
111    }
112}
113
/// Settings that determine how a client connects.
///
/// `ClientConfig` is also the key under which [`get_client`] caches clients,
/// so requests with equal configs reuse a cached client. Every field falls back
/// to its default when absent during deserialization. The fields are private,
/// so outside this module a config is obtained through `Default` or
/// `Deserialize` (NOTE(review): no setters are visible in this file).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub struct ClientConfig {
    /// When `true`, server TLS certificates are not verified at all
    /// (`NoCertVerifier` is installed instead). Defaults to `false`.
    #[serde(default)]
    #[cfg(feature = "https")]
    dangerous_skip_cert_check: bool,
    /// If set, `unix`-scheme URIs are dialed through a
    /// `UnixSocketConnector` built from this string (presumably the socket's
    /// filesystem path). If unset, they fall through to the TCP connector.
    #[serde(default)]
    #[cfg(feature = "unix")]
    unix_socket_path: Option<String>,
}
123
/// Process-wide cache backing [`get_client`], keyed by [`ClientConfig`].
///
/// Values are type-erased (`dyn Any + Send`) because clients are generic over
/// the request body type; [`get_client`] downcasts them back. [`clear_cache`]
/// empties the cache.
static CACHE: LazyLock<Mutex<HashMap<ClientConfig, Box<dyn std::any::Any + Send>>>> =
    LazyLock::new(Default::default);
126
127pub fn clear_cache() {
128    *(*CACHE).lock().unwrap() = Default::default();
129}
130
/// A rustls server-certificate verifier that accepts every certificate.
///
/// It is installed only when `dangerous_skip_cert_check` is set in the
/// [`ClientConfig`]. Server authentication is therefore disabled, which leaves
/// such connections open to man-in-the-middle attacks. Handshake signatures
/// are still checked against the aws-lc-rs algorithm set.
#[cfg(feature = "https")]
#[derive(Debug)]
struct NoCertVerifier {}
134
/// Skips certificate-chain validation but still verifies TLS 1.2 and 1.3
/// handshake signatures using aws-lc-rs' signature algorithms.
#[cfg(feature = "https")]
impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
    /// Unconditionally accepts the presented certificate chain.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer,
        _intermediates: &[rustls::pki_types::CertificateDer],
        _server_name: &rustls::pki_types::ServerName,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        let verdict = rustls::client::danger::ServerCertVerified::assertion();
        Ok(verdict)
    }

    /// Checks a TLS 1.2 handshake signature against the certificate's key.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let provider = rustls::crypto::aws_lc_rs::default_provider();
        let algorithms = &provider.signature_verification_algorithms;
        rustls::crypto::verify_tls12_signature(message, cert, dss, algorithms)
    }

    /// Checks a TLS 1.3 handshake signature against the certificate's key.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let provider = rustls::crypto::aws_lc_rs::default_provider();
        let algorithms = &provider.signature_verification_algorithms;
        rustls::crypto::verify_tls13_signature(message, cert, dss, algorithms)
    }

    /// Advertises the signature schemes aws-lc-rs can verify.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let provider = rustls::crypto::aws_lc_rs::default_provider();
        provider.signature_verification_algorithms.supported_schemes()
    }
}
181
182fn new_client<B>(cfg: ClientConfig) -> anyhow::Result<Client<B>>
183where
184    B: Body + Send + 'static,
185    B::Data: Send,
186{
187    #[cfg(not(feature = "https"))]
188    let http = HttpConnector::new();
189
190    #[cfg(feature = "https")]
191    let http = {
192        use std::sync::Arc;
193        use hyper_rustls::ConfigBuilderExt;
194
195        let mut http = HttpConnector::new();
196        http.enforce_http(false);
197
198        let tls_config = { rustls::ClientConfig::builder() };
199
200        let tls_config = if cfg.dangerous_skip_cert_check {
201            tls_config
202                .dangerous()
203                .with_custom_certificate_verifier(Arc::new(NoCertVerifier {}))
204                .with_no_client_auth()
205        } else {
206            tls_config.with_native_roots()?.with_no_client_auth()
207        };
208
209        let tls = hyper_rustls::HttpsConnectorBuilder::new()
210            .with_tls_config(tls_config)
211            .https_or_http()
212            .enable_http1();
213
214        #[cfg(feature = "http2")]
215        {
216            tls.enable_http2().wrap_connector(http)
217        }
218
219        #[cfg(not(feature = "http2"))]
220        {
221            tls.wrap_connector(http)
222        }
223    };
224
225    let connector = Connector {
226        http,
227        #[cfg(feature = "unix")]
228        unix: cfg
229            .unix_socket_path
230            .map(hyper_unix_socket::UnixSocketConnector::new),
231    };
232
233    let client = HyperClient::builder(TokioExecutor::new())
234        .pool_idle_timeout(Duration::from_secs(30))
235        .build(connector);
236
237    Ok(client)
238}
239
240pub fn get_client<B>(cfg: ClientConfig) -> anyhow::Result<Client<B>>
241where
242    B: Body + Send + 'static,
243    B::Data: Send,
244{
245    let mut cache = (*CACHE).lock().unwrap();
246    if let Some(val) = cache.get(&cfg).and_then(|x| x.downcast_ref::<Client<B>>()) {
247        Ok((val).clone())
248    } else {
249        let new_val = new_client(cfg.clone())?;
250        cache.insert(cfg, Box::new(new_val.clone()));
251        Ok(new_val)
252    }
253}