// httproxide_client_util/lib.rs
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::{LazyLock, Mutex, PoisonError};
use std::task::{Context, Poll};
use std::time::Duration;

use futures_util::future::BoxFuture;
use hyper::body::Body;
use hyper::rt::ReadBufCursor;
use hyper::Uri;
use hyper_util::client::legacy::connect::{Connection, HttpConnector};
use hyper_util::client::legacy::Client as HyperClient;
use hyper_util::rt::{TokioExecutor, TokioIo};
use serde::{Deserialize, Serialize};
use tower::{Service, ServiceExt};

17pub type Client<B> = HyperClient<Connector, B>;
18
19#[derive(Clone, Debug)]
20pub struct Connector {
21 #[cfg(feature = "https")]
22 http: hyper_rustls::HttpsConnector<HttpConnector>,
23 #[cfg(not(feature = "https"))]
24 http: HttpConnector,
25 #[cfg(feature = "unix")]
26 unix: Option<hyper_unix_socket::UnixSocketConnector<String>>,
27}
28
29pub enum Stream {
30 #[cfg(feature = "https")]
31 Http(hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>),
32 #[cfg(not(feature = "https"))]
33 Http(TokioIo<tokio::net::TcpStream>),
34 #[cfg(feature = "unix")]
35 Unix(hyper_unix_socket::UnixSocketConnection),
36}
37
38impl Service<Uri> for Connector {
39 type Response = Stream;
40 type Error = tower::BoxError;
41 type Future = BoxFuture<'static, Result<Stream, Self::Error>>;
42
43 fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
44 Poll::Ready(Ok(()))
45 }
46
47 fn call(&mut self, dst: Uri) -> Self::Future {
48 #[cfg(feature = "unix")]
49 if dst.scheme_str() == Some("unix") {
50 if let Some(unix_ref) = self.unix.as_mut() {
51 let clone = unix_ref.clone();
52 let unix = std::mem::replace(&mut *unix_ref, clone);
53 return Box::pin(async move { Ok(Stream::Unix(unix.oneshot(dst).await?)) });
54 }
55 }
56 let clone = self.http.clone();
57 let http = std::mem::replace(&mut self.http, clone);
58 Box::pin(async move { Ok(Stream::Http(http.oneshot(dst).await?)) })
59 }
60}
61
62impl hyper::rt::Read for Stream {
63 fn poll_read(
64 self: Pin<&mut Self>,
65 cx: &mut Context,
66 buf: ReadBufCursor,
67 ) -> Poll<std::io::Result<()>> {
68 match self.get_mut() {
69 Stream::Http(s) => Pin::new(s).poll_read(cx, buf),
70 #[cfg(feature = "unix")]
71 Stream::Unix(s) => Pin::new(s).poll_read(cx, buf),
72 }
73 }
74}
75
76impl hyper::rt::Write for Stream {
77 fn poll_write(
78 self: Pin<&mut Self>,
79 cx: &mut Context,
80 buf: &[u8],
81 ) -> Poll<std::io::Result<usize>> {
82 match self.get_mut() {
83 Stream::Http(s) => Pin::new(s).poll_write(cx, buf),
84 #[cfg(feature = "unix")]
85 Stream::Unix(s) => Pin::new(s).poll_write(cx, buf),
86 }
87 }
88 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
89 match self.get_mut() {
90 Stream::Http(s) => Pin::new(s).poll_flush(cx),
91 #[cfg(feature = "unix")]
92 Stream::Unix(s) => Pin::new(s).poll_flush(cx),
93 }
94 }
95 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
96 match self.get_mut() {
97 Stream::Http(s) => Pin::new(s).poll_flush(cx),
98 #[cfg(feature = "unix")]
99 Stream::Unix(s) => Pin::new(s).poll_flush(cx),
100 }
101 }
102}
103
104impl Connection for Stream {
105 fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
106 match self {
107 Stream::Http(s) => s.connected(),
108 #[cfg(feature = "unix")]
109 Stream::Unix(s) => s.connected(),
110 }
111 }
112}
113
114#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
115pub struct ClientConfig {
116 #[serde(default)]
117 #[cfg(feature = "https")]
118 dangerous_skip_cert_check: bool,
119 #[serde(default)]
120 #[cfg(feature = "unix")]
121 unix_socket_path: Option<String>,
122}
123
124static CACHE: LazyLock<Mutex<HashMap<ClientConfig, Box<dyn std::any::Any + Send>>>> =
125 LazyLock::new(Default::default);
126
127pub fn clear_cache() {
128 *(*CACHE).lock().unwrap() = Default::default();
129}
130
/// rustls verifier that accepts every server certificate.
///
/// Installed only when `dangerous_skip_cert_check` is set. Handshake signatures
/// are still verified (see its `ServerCertVerifier` impl).
#[cfg(feature = "https")]
#[derive(Debug)]
struct NoCertVerifier;

#[cfg(feature = "https")]
impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
    /// Accepts any certificate chain for any server name, without inspecting it.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Checks a TLS 1.2 handshake signature with aws-lc-rs's algorithms.
    fn verify_tls12_signature(
        &self,
        msg: &[u8],
        signer: &rustls::pki_types::CertificateDer<'_>,
        signature: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let provider = rustls::crypto::aws_lc_rs::default_provider();
        rustls::crypto::verify_tls12_signature(
            msg,
            signer,
            signature,
            &provider.signature_verification_algorithms,
        )
    }

    /// Checks a TLS 1.3 handshake signature with aws-lc-rs's algorithms.
    fn verify_tls13_signature(
        &self,
        msg: &[u8],
        signer: &rustls::pki_types::CertificateDer<'_>,
        signature: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let provider = rustls::crypto::aws_lc_rs::default_provider();
        rustls::crypto::verify_tls13_signature(
            msg,
            signer,
            signature,
            &provider.signature_verification_algorithms,
        )
    }

    /// Advertises exactly the signature schemes aws-lc-rs can verify.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let provider = rustls::crypto::aws_lc_rs::default_provider();
        provider.signature_verification_algorithms.supported_schemes()
    }
}

182fn new_client<B>(cfg: ClientConfig) -> anyhow::Result<Client<B>>
183where
184 B: Body + Send + 'static,
185 B::Data: Send,
186{
187 #[cfg(not(feature = "https"))]
188 let http = HttpConnector::new();
189
190 #[cfg(feature = "https")]
191 let http = {
192 use std::sync::Arc;
193 use hyper_rustls::ConfigBuilderExt;
194
195 let mut http = HttpConnector::new();
196 http.enforce_http(false);
197
198 let tls_config = { rustls::ClientConfig::builder() };
199
200 let tls_config = if cfg.dangerous_skip_cert_check {
201 tls_config
202 .dangerous()
203 .with_custom_certificate_verifier(Arc::new(NoCertVerifier {}))
204 .with_no_client_auth()
205 } else {
206 tls_config.with_native_roots()?.with_no_client_auth()
207 };
208
209 let tls = hyper_rustls::HttpsConnectorBuilder::new()
210 .with_tls_config(tls_config)
211 .https_or_http()
212 .enable_http1();
213
214 #[cfg(feature = "http2")]
215 {
216 tls.enable_http2().wrap_connector(http)
217 }
218
219 #[cfg(not(feature = "http2"))]
220 {
221 tls.wrap_connector(http)
222 }
223 };
224
225 let connector = Connector {
226 http,
227 #[cfg(feature = "unix")]
228 unix: cfg
229 .unix_socket_path
230 .map(hyper_unix_socket::UnixSocketConnector::new),
231 };
232
233 let client = HyperClient::builder(TokioExecutor::new())
234 .pool_idle_timeout(Duration::from_secs(30))
235 .build(connector);
236
237 Ok(client)
238}
239
240pub fn get_client<B>(cfg: ClientConfig) -> anyhow::Result<Client<B>>
241where
242 B: Body + Send + 'static,
243 B::Data: Send,
244{
245 let mut cache = (*CACHE).lock().unwrap();
246 if let Some(val) = cache.get(&cfg).and_then(|x| x.downcast_ref::<Client<B>>()) {
247 Ok((val).clone())
248 } else {
249 let new_val = new_client(cfg.clone())?;
250 cache.insert(cfg, Box::new(new_val.clone()));
251 Ok(new_val)
252 }
253}