1#![cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
2
3use bytes::{Buf, Bytes};
4use http_body_util::{BodyExt, Empty, combinators::BoxBody};
5use hyper::{
6 Request, Response, StatusCode, Uri, Version,
7 body::{Body, Incoming},
8 client, header,
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use std::{
12 collections::HashMap,
13 future::poll_fn,
14 sync::Arc,
15 task::{Context, Poll},
16};
17use tokio::sync::Mutex;
18use tokio::{net::TcpStream, task::JoinHandle};
19
// The TLS backends are mutually exclusive: several items in this module
// (`Error::TlsConnectError`, the `DefaultClient` connector fields and
// constructors) are declared once per backend and would clash otherwise.
#[cfg(all(feature = "native-tls-client", feature = "rustls-client"))]
compile_error!("feature \"native-tls-client\" and feature \"rustls-client\" cannot be enabled at the same time");
24
25#[derive(thiserror::Error, Debug)]
26pub enum Error {
27 #[error("{0} doesn't have an valid host")]
28 InvalidHost(Box<Uri>),
29 #[error(transparent)]
30 IoError(#[from] std::io::Error),
31 #[error(transparent)]
32 HyperError(#[from] hyper::Error),
33 #[error("Failed to connect to {0}, {1}")]
34 ConnectError(Box<Uri>, hyper::Error),
35
36 #[cfg(feature = "native-tls-client")]
37 #[error("Failed to connect with TLS to {0}, {1}")]
38 TlsConnectError(Box<Uri>, native_tls::Error),
39 #[cfg(feature = "native-tls-client")]
40 #[error(transparent)]
41 NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
42
43 #[cfg(feature = "rustls-client")]
44 #[error("Failed to connect with TLS to {0}, {1}")]
45 TlsConnectError(Box<Uri>, std::io::Error),
46
47 #[error("Failed to parse URI: {0}")]
48 UriParsingError(#[from] hyper::http::uri::InvalidUri),
49
50 #[error("Failed to parse URI parts: {0}")]
51 UriPartsError(#[from] hyper::http::uri::InvalidUriParts),
52
53 #[error("TLS connector initialization failed: {0}")]
54 TlsConnectorError(String),
55}
56
57pub struct Upgraded {
59 pub client: TokioIo<hyper::upgrade::Upgraded>,
61 pub server: TokioIo<hyper::upgrade::Upgraded>,
63}
64
65type DynError = Box<dyn std::error::Error + Send + Sync>;
66type PooledBody = BoxBody<Bytes, DynError>;
67type Http1Sender = hyper::client::conn::http1::SendRequest<PooledBody>;
68type Http2Sender = hyper::client::conn::http2::SendRequest<PooledBody>;
69
/// Protocol spoken on a pooled connection. It is part of the pool key, so HTTP/1
/// and HTTP/2 connections to the same origin are tracked separately.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
enum ConnectionProtocol {
    /// HTTP/1.x connection.
    Http1,
    /// HTTP/2 connection.
    Http2,
}
75
76#[derive(Clone, Debug, Eq, PartialEq, Hash)]
77struct ConnectionKey {
78 host: String,
79 port: u16,
80 is_tls: bool,
81 protocol: ConnectionProtocol,
82}
83
84impl ConnectionKey {
85 fn new(host: String, port: u16, is_tls: bool, protocol: ConnectionProtocol) -> Self {
86 Self {
87 host,
88 port,
89 is_tls,
90 protocol,
91 }
92 }
93
94 fn from_uri(uri: &Uri, protocol: ConnectionProtocol) -> Result<Self, Error> {
95 let (host, port, is_tls) = host_port(uri)?;
96 Ok(ConnectionKey::new(host, port, is_tls, protocol))
97 }
98}
99
100#[derive(Clone, Default)]
101struct ConnectionPool {
102 http1: Arc<Mutex<HashMap<ConnectionKey, Vec<Http1Sender>>>>,
103 http2: Arc<Mutex<HashMap<ConnectionKey, Http2Sender>>>,
104}
105
106impl ConnectionPool {
107 async fn take_http1(&self, key: &ConnectionKey) -> Option<Http1Sender> {
108 let mut guard = self.http1.lock().await;
109 let entry = guard.get_mut(key)?;
110 while let Some(mut conn) = entry.pop() {
111 if sender_alive_http1(&mut conn).await {
112 return Some(conn);
113 }
114 }
115 if entry.is_empty() {
116 guard.remove(key);
117 }
118 None
119 }
120
121 async fn put_http1(&self, key: ConnectionKey, sender: Http1Sender) {
122 let mut guard = self.http1.lock().await;
123 guard.entry(key).or_default().push(sender);
124 }
125
126 async fn get_http2(&self, key: &ConnectionKey) -> Option<Http2Sender> {
127 let mut guard = self.http2.lock().await;
128 let mut sender = guard.get(key).cloned()?;
129
130 let alive = sender_alive_http2(&mut sender).await;
131
132 if alive {
133 Some(sender)
134 } else {
135 guard.remove(key);
136 None
137 }
138 }
139
140 async fn insert_http2_if_absent(&self, key: ConnectionKey, sender: Http2Sender) {
141 let mut guard = self.http2.lock().await;
142 guard.entry(key).or_insert(sender);
143 }
144}
145
146async fn sender_alive_http1(sender: &mut Http1Sender) -> bool {
147 poll_fn(|cx| sender.poll_ready(cx)).await.is_ok()
148}
149
150async fn sender_alive_http2(sender: &mut Http2Sender) -> bool {
151 poll_fn(|cx| sender.poll_ready(cx)).await.is_ok()
152}
153
154#[derive(Clone)]
155pub struct DefaultClient {
157 #[cfg(feature = "native-tls-client")]
158 tls_connector_no_alpn: tokio_native_tls::TlsConnector,
159 #[cfg(feature = "native-tls-client")]
160 tls_connector_alpn_h2: tokio_native_tls::TlsConnector,
161
162 #[cfg(feature = "rustls-client")]
163 tls_connector_no_alpn: tokio_rustls::TlsConnector,
164 #[cfg(feature = "rustls-client")]
165 tls_connector_alpn_h2: tokio_rustls::TlsConnector,
166
167 pub with_upgrades: bool,
170
171 pool: ConnectionPool,
172}
173impl Default for DefaultClient {
174 fn default() -> Self {
175 Self::new()
176 }
177}
178
179impl DefaultClient {
180 #[cfg(feature = "native-tls-client")]
181 pub fn new() -> Self {
182 Self::try_new().unwrap_or_else(|err| {
183 panic!("Failed to create DefaultClient: {err}");
184 })
185 }
186
187 #[cfg(feature = "native-tls-client")]
188 pub fn try_new() -> Result<Self, Error> {
189 let tls_connector_no_alpn = native_tls::TlsConnector::builder().build().map_err(|e| {
190 Error::TlsConnectorError(format!("Failed to build no-ALPN connector: {e}"))
191 })?;
192 let tls_connector_alpn_h2 = native_tls::TlsConnector::builder()
193 .request_alpns(&["h2", "http/1.1"])
194 .build()
195 .map_err(|e| {
196 Error::TlsConnectorError(format!("Failed to build ALPN-H2 connector: {e}"))
197 })?;
198
199 Ok(Self {
200 tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
201 tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
202 with_upgrades: false,
203 pool: ConnectionPool::default(),
204 })
205 }
206
207 #[cfg(feature = "rustls-client")]
208 pub fn new() -> Self {
209 Self::try_new().unwrap_or_else(|err| {
210 panic!("Failed to create DefaultClient: {}", err);
211 })
212 }
213
214 #[cfg(feature = "rustls-client")]
215 pub fn try_new() -> Result<Self, Error> {
216 use std::sync::Arc;
217
218 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
219 root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
220
221 let tls_connector_no_alpn = tokio_rustls::rustls::ClientConfig::builder()
222 .with_root_certificates(root_cert_store.clone())
223 .with_no_client_auth();
224 let mut tls_connector_alpn_h2 = tokio_rustls::rustls::ClientConfig::builder()
225 .with_root_certificates(root_cert_store.clone())
226 .with_no_client_auth();
227 tls_connector_alpn_h2.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
228
229 Ok(Self {
230 tls_connector_no_alpn: tokio_rustls::TlsConnector::from(Arc::new(
231 tls_connector_no_alpn,
232 )),
233 tls_connector_alpn_h2: tokio_rustls::TlsConnector::from(Arc::new(
234 tls_connector_alpn_h2,
235 )),
236 with_upgrades: false,
237 pool: ConnectionPool::default(),
238 })
239 }
240
241 pub fn with_upgrades(mut self) -> Self {
244 self.with_upgrades = true;
245 self
246 }
247
248 #[cfg(feature = "native-tls-client")]
249 fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector {
250 match http_version {
251 Version::HTTP_2 => &self.tls_connector_alpn_h2,
252 _ => &self.tls_connector_no_alpn,
253 }
254 }
255
256 #[cfg(feature = "rustls-client")]
257 fn tls_connector(&self, http_version: Version) -> &tokio_rustls::TlsConnector {
258 match http_version {
259 Version::HTTP_2 => &self.tls_connector_alpn_h2,
260 _ => &self.tls_connector_no_alpn,
261 }
262 }
263
264 pub async fn send_request<B>(
268 &self,
269 req: Request<B>,
270 ) -> Result<
271 (
272 Response<Incoming>,
273 Option<JoinHandle<Result<Upgraded, hyper::Error>>>,
274 ),
275 Error,
276 >
277 where
278 B: Body<Data = Bytes> + Send + Sync + 'static,
279 B::Data: Send + Buf,
280 B::Error: Into<DynError>,
281 {
282 let target_uri = req.uri().clone();
283 let mut send_request = if req.version() == Version::HTTP_2 {
284 match ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http2) {
285 Ok(pool_key) => {
286 if let Some(conn) = self.pool.get_http2(&pool_key).await {
287 SendRequest::Http2(conn)
288 } else {
289 self.connect(req.uri(), req.version(), Some(pool_key))
290 .await?
291 }
292 }
293 Err(err) => {
294 tracing::warn!(
295 "ConnectionKey::from_uri failed for HTTP/2 ({}): continuing without pool",
296 err
297 );
298 self.connect(req.uri(), req.version(), None).await?
299 }
300 }
301 } else {
302 match ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1) {
303 Ok(pool_key) => {
304 if let Some(conn) = self.pool.take_http1(&pool_key).await {
305 SendRequest::Http1(conn)
306 } else {
307 self.connect(req.uri(), req.version(), Some(pool_key))
308 .await?
309 }
310 }
311 Err(err) => {
312 tracing::warn!(
313 "ConnectionKey::from_uri failed for HTTP/1 ({}): continuing without pool",
314 err
315 );
316 self.connect(req.uri(), req.version(), None).await?
317 }
318 }
319 };
320
321 let (req_parts, req_body) = req.into_parts();
322
323 let boxed_req = Request::from_parts(req_parts.clone(), to_boxed_body(req_body));
324
325 let res = send_request.send_request(boxed_req).await?;
326
327 if res.status() == StatusCode::SWITCHING_PROTOCOLS {
328 let (res_parts, res_body) = res.into_parts();
329
330 let client_request = Request::from_parts(req_parts, Empty::<Bytes>::new());
331 let server_response = Response::from_parts(res_parts.clone(), Empty::<Bytes>::new());
332
333 let upgrade = if self.with_upgrades {
334 Some(tokio::task::spawn(async move {
335 let client = hyper::upgrade::on(client_request).await?;
336 let server = hyper::upgrade::on(server_response).await?;
337
338 Ok(Upgraded {
339 client: TokioIo::new(client),
340 server: TokioIo::new(server),
341 })
342 }))
343 } else {
344 tokio::task::spawn(async move {
345 let client = hyper::upgrade::on(client_request).await?;
346 let server = hyper::upgrade::on(server_response).await?;
347
348 let _ = tokio::io::copy_bidirectional(
349 &mut TokioIo::new(client),
350 &mut TokioIo::new(server),
351 )
352 .await;
353
354 Ok::<_, hyper::Error>(())
355 });
356 None
357 };
358
359 Ok((Response::from_parts(res_parts, res_body), upgrade))
360 } else {
361 match send_request {
362 SendRequest::Http1(sender) => {
363 if let Ok(pool_key) =
364 ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1)
365 {
366 self.pool.put_http1(pool_key, sender).await;
367 } else {
368 }
370 }
371 SendRequest::Http2(_) => {
372 }
374 }
375 Ok((res, None))
376 }
377 }
378
379 async fn connect(
380 &self,
381 uri: &Uri,
382 http_version: Version,
383 key: Option<ConnectionKey>,
384 ) -> Result<SendRequest, Error> {
385 let (host, port, is_tls) = host_port(uri)?;
386
387 let tcp = TcpStream::connect((host.as_str(), port)).await?;
388 let _ = tcp.set_nodelay(true);
390
391 if is_tls {
392 #[cfg(feature = "native-tls-client")]
393 let tls = self
394 .tls_connector(http_version)
395 .connect(&host, tcp)
396 .await
397 .map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?;
398 #[cfg(feature = "rustls-client")]
399 let tls = self
400 .tls_connector(http_version)
401 .connect(
402 host.to_string()
403 .try_into()
404 .map_err(|_| Error::InvalidHost(Box::new(uri.clone())))?,
405 tcp,
406 )
407 .await
408 .map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?;
409
410 #[cfg(feature = "native-tls-client")]
411 let is_h2 = matches!(
412 tls.get_ref()
413 .negotiated_alpn()
414 .map(|a| a.map(|b| b == b"h2")),
415 Ok(Some(true))
416 );
417
418 #[cfg(feature = "rustls-client")]
419 let is_h2 = tls.get_ref().1.alpn_protocol() == Some(b"h2");
420
421 if is_h2 {
422 let (sender, conn) = client::conn::http2::Builder::new(TokioExecutor::new())
423 .handshake(TokioIo::new(tls))
424 .await
425 .map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
426
427 tokio::spawn(conn);
428
429 if let Some(ref k) = key
430 && matches!(k.protocol, ConnectionProtocol::Http2)
431 {
432 self.pool
433 .insert_http2_if_absent(k.clone(), sender.clone())
434 .await;
435 }
436
437 Ok(SendRequest::Http2(sender))
438 } else {
439 let (sender, conn) = client::conn::http1::Builder::new()
440 .preserve_header_case(true)
441 .title_case_headers(true)
442 .handshake(TokioIo::new(tls))
443 .await
444 .map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
445
446 tokio::spawn(conn.with_upgrades());
447
448 Ok(SendRequest::Http1(sender))
449 }
450 } else {
451 let (sender, conn) = client::conn::http1::Builder::new()
452 .preserve_header_case(true)
453 .title_case_headers(true)
454 .handshake(TokioIo::new(tcp))
455 .await
456 .map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
457 tokio::spawn(conn.with_upgrades());
458 Ok(SendRequest::Http1(sender))
459 }
460 }
461}
462
463enum SendRequest {
464 Http1(Http1Sender),
465 Http2(Http2Sender),
466}
467
468impl SendRequest {
469 async fn send_request(
470 &mut self,
471 mut req: Request<PooledBody>,
472 ) -> Result<Response<Incoming>, hyper::Error> {
473 match self {
474 SendRequest::Http1(sender) => {
475 if req.version() == hyper::Version::HTTP_2
476 && let Some(authority) = req.uri().authority().cloned()
477 {
478 match authority.as_str().parse::<header::HeaderValue>() {
479 Ok(host_value) => {
480 req.headers_mut().insert(header::HOST, host_value);
481 }
482 Err(err) => {
483 tracing::warn!(
484 "Failed to parse authority '{}' as HOST header: {}",
485 authority,
486 err
487 );
488 }
489 }
490 }
491 if let Err(err) = remove_authority(&mut req) {
492 tracing::error!("Failed to remove authority from URI: {}", err);
493 }
495 sender.send_request(req).await
496 }
497 SendRequest::Http2(sender) => {
498 if req.version() != hyper::Version::HTTP_2 {
499 req.headers_mut().remove(header::HOST);
500 }
501 sender.send_request(req).await
502 }
503 }
504 }
505}
506
507impl SendRequest {
508 #[allow(dead_code)]
509 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), hyper::Error>> {
511 match self {
512 SendRequest::Http1(sender) => sender.poll_ready(cx),
513 SendRequest::Http2(_sender) => Poll::Ready(Ok(())),
514 }
515 }
516}
517
518fn remove_authority<B>(req: &mut Request<B>) -> Result<(), hyper::http::uri::InvalidUriParts> {
519 let mut parts = req.uri().clone().into_parts();
520 parts.scheme = None;
521 parts.authority = None;
522 *req.uri_mut() = Uri::from_parts(parts)?;
523 Ok(())
524}
525
526fn to_boxed_body<B>(body: B) -> PooledBody
527where
528 B: Body<Data = Bytes> + Send + Sync + 'static,
529 B::Data: Send + Buf,
530 B::Error: Into<DynError>,
531{
532 body.map_err(|err| err.into()).boxed()
533}
534
535fn host_port(uri: &Uri) -> Result<(String, u16, bool), Error> {
536 let host = uri
537 .host()
538 .ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))?
539 .to_string();
540 let is_tls = uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS);
541 let port = uri.port_u16().unwrap_or(if is_tls { 443 } else { 80 });
542 Ok((host, port, is_tls))
543}
544