1#![cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
2
3use bytes::Bytes;
4use http_body_util::Empty;
5use hyper::{
6 Request, Response, StatusCode, Uri, Version,
7 body::{Body, Incoming},
8 client, header,
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use std::task::{Context, Poll};
12use tokio::{net::TcpStream, task::JoinHandle};
13
// The native-tls and rustls backends each define their own `tls_connector_*` fields,
// `try_new`/`tls_connector` methods and `Error::TlsConnectError` variant, so enabling
// both would produce conflicting definitions; reject that combination with a clear message.
#[cfg(all(feature = "native-tls-client", feature = "rustls-client"))]
compile_error!(
    "feature \"native-tls-client\" and feature \"rustls-client\" cannot be enabled at the same time"
);
18
19#[derive(thiserror::Error, Debug)]
20pub enum Error {
21 #[error("{0} doesn't have an valid host")]
22 InvalidHost(Box<Uri>),
23 #[error(transparent)]
24 IoError(#[from] std::io::Error),
25 #[error(transparent)]
26 HyperError(#[from] hyper::Error),
27 #[error("Failed to connect to {0}, {1}")]
28 ConnectError(Box<Uri>, hyper::Error),
29
30 #[cfg(feature = "native-tls-client")]
31 #[error("Failed to connect with TLS to {0}, {1}")]
32 TlsConnectError(Box<Uri>, native_tls::Error),
33 #[cfg(feature = "native-tls-client")]
34 #[error(transparent)]
35 NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
36
37 #[cfg(feature = "rustls-client")]
38 #[error("Failed to connect with TLS to {0}, {1}")]
39 TlsConnectError(Box<Uri>, std::io::Error),
40
41 #[error("Failed to parse URI: {0}")]
42 UriParsingError(#[from] hyper::http::uri::InvalidUri),
43
44 #[error("Failed to parse URI parts: {0}")]
45 UriPartsError(#[from] hyper::http::uri::InvalidUriParts),
46
47 #[error("TLS connector initialization failed: {0}")]
48 TlsConnectorError(String),
49}
50
/// The pair of upgraded streams produced when the upstream server answers
/// `101 Switching Protocols` and upgrade hand-off is enabled on the client.
pub struct Upgraded {
    /// Upgraded I/O obtained via `hyper::upgrade::on` from the original request's parts —
    /// presumably the connection the request arrived on (TODO confirm with callers).
    pub client: TokioIo<hyper::upgrade::Upgraded>,
    /// Upgraded I/O obtained via `hyper::upgrade::on` from the upstream server's response.
    pub server: TokioIo<hyper::upgrade::Upgraded>,
}
/// HTTP client that opens a new connection for every request, speaking HTTP/1.1 or,
/// over TLS when ALPN negotiates it, HTTP/2.
#[derive(Clone)]
pub struct DefaultClient {
    /// TLS connector that advertises no ALPN protocols (used for non-HTTP/2 requests).
    #[cfg(feature = "native-tls-client")]
    tls_connector_no_alpn: tokio_native_tls::TlsConnector,
    /// TLS connector that offers `h2` and `http/1.1` via ALPN (used for HTTP/2 requests).
    #[cfg(feature = "native-tls-client")]
    tls_connector_alpn_h2: tokio_native_tls::TlsConnector,

    /// TLS connector that advertises no ALPN protocols (used for non-HTTP/2 requests).
    #[cfg(feature = "rustls-client")]
    tls_connector_no_alpn: tokio_rustls::TlsConnector,
    /// TLS connector that offers `h2` and `http/1.1` via ALPN (used for HTTP/2 requests).
    #[cfg(feature = "rustls-client")]
    tls_connector_alpn_h2: tokio_rustls::TlsConnector,

    /// On a `101 Switching Protocols` response: when `true`, `send_request` returns a task
    /// handle resolving to the upgraded streams; when `false`, it relays bytes between them
    /// in a detached background task.
    pub with_upgrades: bool,
}
75impl Default for DefaultClient {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81impl DefaultClient {
82 #[cfg(feature = "native-tls-client")]
83 pub fn new() -> Self {
84 Self::try_new().unwrap_or_else(|err| {
85 panic!("Failed to create DefaultClient: {err}");
86 })
87 }
88
89 #[cfg(feature = "native-tls-client")]
90 pub fn try_new() -> Result<Self, Error> {
91 let tls_connector_no_alpn = native_tls::TlsConnector::builder().build().map_err(|e| {
92 Error::TlsConnectorError(format!("Failed to build no-ALPN connector: {e}"))
93 })?;
94 let tls_connector_alpn_h2 = native_tls::TlsConnector::builder()
95 .request_alpns(&["h2", "http/1.1"])
96 .build()
97 .map_err(|e| {
98 Error::TlsConnectorError(format!("Failed to build ALPN-H2 connector: {e}"))
99 })?;
100
101 Ok(Self {
102 tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
103 tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
104 with_upgrades: false,
105 })
106 }
107
108 #[cfg(feature = "rustls-client")]
109 pub fn new() -> Self {
110 Self::try_new().unwrap_or_else(|err| {
111 panic!("Failed to create DefaultClient: {}", err);
112 })
113 }
114
115 #[cfg(feature = "rustls-client")]
116 pub fn try_new() -> Result<Self, Error> {
117 use std::sync::Arc;
118
119 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
120 root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
121
122 let tls_connector_no_alpn = tokio_rustls::rustls::ClientConfig::builder()
123 .with_root_certificates(root_cert_store.clone())
124 .with_no_client_auth();
125 let mut tls_connector_alpn_h2 = tokio_rustls::rustls::ClientConfig::builder()
126 .with_root_certificates(root_cert_store.clone())
127 .with_no_client_auth();
128 tls_connector_alpn_h2.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
129
130 Ok(Self {
131 tls_connector_no_alpn: tokio_rustls::TlsConnector::from(Arc::new(
132 tls_connector_no_alpn,
133 )),
134 tls_connector_alpn_h2: tokio_rustls::TlsConnector::from(Arc::new(
135 tls_connector_alpn_h2,
136 )),
137 with_upgrades: false,
138 })
139 }
140
141 pub fn with_upgrades(mut self) -> Self {
144 self.with_upgrades = true;
145 self
146 }
147
148 #[cfg(feature = "native-tls-client")]
149 fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector {
150 match http_version {
151 Version::HTTP_2 => &self.tls_connector_alpn_h2,
152 _ => &self.tls_connector_no_alpn,
153 }
154 }
155
156 #[cfg(feature = "rustls-client")]
157 fn tls_connector(&self, http_version: Version) -> &tokio_rustls::TlsConnector {
158 match http_version {
159 Version::HTTP_2 => &self.tls_connector_alpn_h2,
160 _ => &self.tls_connector_no_alpn,
161 }
162 }
163
164 pub async fn send_request<B>(
168 &self,
169 req: Request<B>,
170 ) -> Result<
171 (
172 Response<Incoming>,
173 Option<JoinHandle<Result<Upgraded, hyper::Error>>>,
174 ),
175 Error,
176 >
177 where
178 B: Body + Unpin + Send + 'static,
179 B::Data: Send,
180 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
181 {
182 let mut send_request = self.connect(req.uri(), req.version()).await?;
183
184 let (req_parts, req_body) = req.into_parts();
185
186 let res = send_request
187 .send_request(Request::from_parts(req_parts.clone(), req_body))
188 .await?;
189
190 if res.status() == StatusCode::SWITCHING_PROTOCOLS {
191 let (res_parts, res_body) = res.into_parts();
192
193 let client_request = Request::from_parts(req_parts, Empty::<Bytes>::new());
194 let server_response = Response::from_parts(res_parts.clone(), Empty::<Bytes>::new());
195
196 let upgrade = if self.with_upgrades {
197 Some(tokio::task::spawn(async move {
198 let client = hyper::upgrade::on(client_request).await?;
199 let server = hyper::upgrade::on(server_response).await?;
200
201 Ok(Upgraded {
202 client: TokioIo::new(client),
203 server: TokioIo::new(server),
204 })
205 }))
206 } else {
207 tokio::task::spawn(async move {
208 let client = hyper::upgrade::on(client_request).await?;
209 let server = hyper::upgrade::on(server_response).await?;
210
211 let _ = tokio::io::copy_bidirectional(
212 &mut TokioIo::new(client),
213 &mut TokioIo::new(server),
214 )
215 .await;
216
217 Ok::<_, hyper::Error>(())
218 });
219 None
220 };
221
222 Ok((Response::from_parts(res_parts, res_body), upgrade))
223 } else {
224 Ok((res, None))
225 }
226 }
227
228 async fn connect<B>(&self, uri: &Uri, http_version: Version) -> Result<SendRequest<B>, Error>
229 where
230 B: Body + Unpin + Send + 'static,
231 B::Data: Send,
232 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
233 {
234 let host = uri
235 .host()
236 .ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))?;
237 let port =
238 uri.port_u16()
239 .unwrap_or(if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
240 443
241 } else {
242 80
243 });
244
245 let tcp = TcpStream::connect((host, port)).await?;
246 let _ = tcp.set_nodelay(true);
248
249 if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
250 #[cfg(feature = "native-tls-client")]
251 let tls = self
252 .tls_connector(http_version)
253 .connect(host, tcp)
254 .await
255 .map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?;
256 #[cfg(feature = "rustls-client")]
257 let tls = self
258 .tls_connector(http_version)
259 .connect(
260 host.to_string()
261 .try_into()
262 .map_err(|_| Error::InvalidHost(Box::new(uri.clone())))?,
263 tcp,
264 )
265 .await
266 .map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?;
267
268 #[cfg(feature = "native-tls-client")]
269 let is_h2 = matches!(
270 tls.get_ref()
271 .negotiated_alpn()
272 .map(|a| a.map(|b| b == b"h2")),
273 Ok(Some(true))
274 );
275
276 #[cfg(feature = "rustls-client")]
277 let is_h2 = tls.get_ref().1.alpn_protocol() == Some(b"h2");
278
279 if is_h2 {
280 let (sender, conn) = client::conn::http2::Builder::new(TokioExecutor::new())
281 .handshake(TokioIo::new(tls))
282 .await
283 .map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
284
285 tokio::spawn(conn);
286
287 Ok(SendRequest::Http2(sender))
288 } else {
289 let (sender, conn) = client::conn::http1::Builder::new()
290 .preserve_header_case(true)
291 .title_case_headers(true)
292 .handshake(TokioIo::new(tls))
293 .await
294 .map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
295
296 tokio::spawn(conn.with_upgrades());
297
298 Ok(SendRequest::Http1(sender))
299 }
300 } else {
301 let (sender, conn) = client::conn::http1::Builder::new()
302 .preserve_header_case(true)
303 .title_case_headers(true)
304 .handshake(TokioIo::new(tcp))
305 .await
306 .map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
307 tokio::spawn(conn.with_upgrades());
308 Ok(SendRequest::Http1(sender))
309 }
310 }
311}
312
/// Request sender for a single connection, tagged with the HTTP version in use.
enum SendRequest<B> {
    /// Sender for an HTTP/1.x connection (plain TCP, or TLS without `h2` negotiated).
    Http1(hyper::client::conn::http1::SendRequest<B>),
    /// Sender for an HTTP/2 connection (TLS with `h2` negotiated via ALPN).
    Http2(hyper::client::conn::http2::SendRequest<B>),
}
317
318impl<B> SendRequest<B>
319where
320 B: Body + 'static,
321{
322 async fn send_request(
323 &mut self,
324 mut req: Request<B>,
325 ) -> Result<Response<Incoming>, hyper::Error> {
326 match self {
327 SendRequest::Http1(sender) => {
328 if req.version() == hyper::Version::HTTP_2 {
329 if let Some(authority) = req.uri().authority().cloned() {
330 match authority.as_str().parse::<header::HeaderValue>() {
331 Ok(host_value) => {
332 req.headers_mut().insert(header::HOST, host_value);
333 }
334 Err(err) => {
335 tracing::warn!(
336 "Failed to parse authority '{}' as HOST header: {}",
337 authority,
338 err
339 );
340 }
341 }
342 }
343 }
344 if let Err(err) = remove_authority(&mut req) {
345 tracing::error!("Failed to remove authority from URI: {}", err);
346 }
348 sender.send_request(req).await
349 }
350 SendRequest::Http2(sender) => {
351 if req.version() != hyper::Version::HTTP_2 {
352 req.headers_mut().remove(header::HOST);
353 }
354 sender.send_request(req).await
355 }
356 }
357 }
358}
359
360impl<B> SendRequest<B> {
361 #[allow(dead_code)]
362 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), hyper::Error>> {
364 match self {
365 SendRequest::Http1(sender) => sender.poll_ready(cx),
366 SendRequest::Http2(sender) => sender.poll_ready(cx),
367 }
368 }
369}
370
371fn remove_authority<B>(req: &mut Request<B>) -> Result<(), hyper::http::uri::InvalidUriParts> {
372 let mut parts = req.uri().clone().into_parts();
373 parts.scheme = None;
374 parts.authority = None;
375 *req.uri_mut() = Uri::from_parts(parts)?;
376 Ok(())
377}