// httproxide_client_util/lib.rs
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::PoisonError;
use std::task::{Context, Poll};
use std::time::Duration;

use futures_util::future::BoxFuture;
use hyper::body::HttpBody;
use hyper::client::connect::Connection;
use hyper::client::HttpConnector;
use hyper::{Client as HyperClient, Uri};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tower::{Service, ServiceExt};

17pub type Client<B> = HyperClient<Connector, B>;
18
19#[derive(Clone, Debug)]
20pub struct Connector {
21 #[cfg(feature = "https")]
22 http: hyper_rustls::HttpsConnector<HttpConnector>,
23 #[cfg(not(feature = "https"))]
24 http: HttpConnector,
25 #[cfg(feature = "unix")]
26 unix: hyperlocal::UnixConnector,
27}
28
29pub enum Stream {
30 #[cfg(feature = "https")]
31 Http(hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream>),
32 #[cfg(not(feature = "https"))]
33 Http(tokio::net::TcpStream),
34 #[cfg(feature = "unix")]
35 Unix(hyperlocal::UnixStream),
36}
37
38impl Service<Uri> for Connector {
39 type Response = Stream;
40 type Error = tower::BoxError;
41 type Future = BoxFuture<'static, Result<Stream, Self::Error>>;
42
43 fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
44 Poll::Ready(Ok(()))
45 }
46
47 fn call(&mut self, dst: Uri) -> Self::Future {
48 #[cfg(feature = "unix")]
49 if dst.scheme_str() == Some("unix") {
50 let clone = self.unix.clone();
51 let unix = std::mem::replace(&mut self.unix, clone);
52 return Box::pin(async move { Ok(Stream::Unix(unix.oneshot(dst).await?)) });
53 }
54 let clone = self.http.clone();
55 let http = std::mem::replace(&mut self.http, clone);
56 Box::pin(async move { Ok(Stream::Http(http.oneshot(dst).await?)) })
57 }
58}
59
60impl AsyncRead for Stream {
61 fn poll_read(
62 self: Pin<&mut Self>,
63 cx: &mut Context,
64 buf: &mut ReadBuf,
65 ) -> Poll<std::io::Result<()>> {
66 match self.get_mut() {
67 Stream::Http(s) => Pin::new(s).poll_read(cx, buf),
68 #[cfg(feature = "unix")]
69 Stream::Unix(s) => Pin::new(s).poll_read(cx, buf),
70 }
71 }
72}
73
74impl AsyncWrite for Stream {
75 fn poll_write(
76 self: Pin<&mut Self>,
77 cx: &mut Context,
78 buf: &[u8],
79 ) -> Poll<std::io::Result<usize>> {
80 match self.get_mut() {
81 Stream::Http(ref mut s) => Pin::new(s).poll_write(cx, buf),
82 #[cfg(feature = "unix")]
83 Stream::Unix(ref mut s) => Pin::new(s).poll_write(cx, buf),
84 }
85 }
86 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
87 match self.get_mut() {
88 Stream::Http(ref mut s) => Pin::new(s).poll_flush(cx),
89 #[cfg(feature = "unix")]
90 Stream::Unix(ref mut s) => Pin::new(s).poll_flush(cx),
91 }
92 }
93 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
94 match self.get_mut() {
95 Stream::Http(ref mut s) => Pin::new(s).poll_flush(cx),
96 #[cfg(feature = "unix")]
97 Stream::Unix(ref mut s) => Pin::new(s).poll_flush(cx),
98 }
99 }
100}
101
102impl Connection for Stream {
103 fn connected(&self) -> hyper::client::connect::Connected {
104 match self {
105 Stream::Http(s) => s.connected(),
106 #[cfg(feature = "unix")]
107 Stream::Unix(s) => s.connected(),
108 }
109 }
110}
111
112#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
113pub struct ClientConfig {
114 dangerous_skip_cert_check: bool,
115}
116
117lazy_static::lazy_static! {
118 static ref CACHE: Mutex<HashMap<ClientConfig, Box<dyn std::any::Any + Send>>> = {
119 Mutex::new(HashMap::new())
120 };
121}
122
123pub fn clear_cache() {
124 *(*CACHE).lock().unwrap() = HashMap::new();
125}
126
/// rustls certificate verifier that accepts every server certificate.
///
/// It is selected only when `ClientConfig::dangerous_skip_cert_check` is set. It
/// disables authentication of the server, so anyone able to intercept the
/// connection can impersonate it.
#[cfg(feature = "https")]
struct NoCertVerifier;

#[cfg(feature = "https")]
impl rustls::client::ServerCertVerifier for NoCertVerifier {
    /// Reports the presented certificate as valid without inspecting it.
    ///
    /// The end-entity certificate, intermediates, server name, SCTs, OCSP response
    /// and current time are all ignored.
    fn verify_server_cert(
        &self,
        _: &rustls::Certificate,
        _: &[rustls::Certificate],
        _: &rustls::client::ServerName,
        _: &mut dyn Iterator<Item = &[u8]>,
        _: &[u8],
        _: std::time::SystemTime,
    ) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
        let verified = rustls::client::ServerCertVerified::assertion();
        Ok(verified)
    }
}

145fn new_client<B>(cfg: ClientConfig) -> anyhow::Result<Client<B>>
146where
147 B: HttpBody + Send + 'static,
148 B::Data: Send,
149{
150 let mut http = HttpConnector::new();
151
152 #[cfg(feature = "https")]
153 let https = {
154 use hyper_rustls::ConfigBuilderExt;
155
156 http.enforce_http(false);
157
158 let tls_config = {
159 rustls::ClientConfig::builder()
160 .with_safe_defaults()
161 };
162
163 let tls_config = if cfg.dangerous_skip_cert_check {
164 tls_config
165 .with_custom_certificate_verifier(Arc::new(NoCertVerifier {}))
166 .with_no_client_auth()
167 } else {
168 tls_config
169 .with_native_roots()
170 .with_no_client_auth()
171 };
172
173 let tls = hyper_rustls::HttpsConnectorBuilder::new()
174 .with_tls_config(tls_config)
175 .https_or_http()
176 .enable_http1();
177
178 #[cfg(feature = "http2")]
179 {
180 tls.enable_http2().wrap_connector(http)
181 }
182
183 #[cfg(not(feature = "http2"))]
184 {
185 tls.wrap_connector(http)
186 }
187 };
188
189 let connector = Connector {
190 #[cfg(feature = "https")]
191 http: https,
192 #[cfg(not(feature = "https"))]
193 http,
194 #[cfg(feature = "unix")]
195 unix: hyperlocal::UnixConnector,
196 };
197
198 let client = HyperClient::builder()
199 .pool_idle_timeout(Duration::from_secs(30))
200 .build(connector);
201
202 Ok(client)
203}
204
205pub fn get_client<B>(cfg: ClientConfig) -> anyhow::Result<Client<B>>
206where
207 B: HttpBody + Send + 'static,
208 B::Data: Send,
209{
210 let mut cache = (*CACHE).lock().unwrap();
211 if let Some(ref val) = cache.get(&cfg).and_then(|x| x.downcast_ref::<Client<B>>()) {
212 Ok((*val).clone())
213 } else {
214 let new_val = new_client(cfg.clone())?;
215 cache.insert(cfg, Box::new(new_val.clone()));
216 Ok(new_val)
217 }
218}