1#![cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
2
3use bytes::Bytes;
4use http_body_util::Empty;
5use hyper::{
6 Request, Response, StatusCode, Uri, Version,
7 body::{Body, Incoming},
8 client, header,
9};
10use hyper_util::rt::{TokioExecutor, TokioIo};
11use std::task::{Context, Poll};
12use tokio::{net::TcpStream, task::JoinHandle};
13
// Both TLS backends define the same items (`DefaultClient::new`, the TLS connector
// fields, `Error::TlsConnectError`), so enabling both would only yield a cascade of
// duplicate-definition errors; reject the combination with one clear message instead.
#[cfg(all(feature = "rustls-client", feature = "native-tls-client"))]
compile_error!(
    "feature \"native-tls-client\" and feature \"rustls-client\" cannot be enabled at the same time"
);
18
19#[derive(thiserror::Error, Debug)]
20pub enum Error {
21 #[error("{0} doesn't have an valid host")]
22 InvalidHost(Uri),
23 #[error(transparent)]
24 IoError(#[from] std::io::Error),
25 #[error(transparent)]
26 HyperError(#[from] hyper::Error),
27 #[error("Failed to connect to {0}, {1}")]
28 ConnectError(Uri, hyper::Error),
29
30 #[cfg(feature = "native-tls-client")]
31 #[error("Failed to connect with TLS to {0}, {1}")]
32 TlsConnectError(Uri, native_tls::Error),
33 #[cfg(feature = "native-tls-client")]
34 #[error(transparent)]
35 NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
36
37 #[cfg(feature = "rustls-client")]
38 #[error("Failed to connect with TLS to {0}, {1}")]
39 TlsConnectError(Uri, std::io::Error),
40}
41
42pub struct Upgraded {
44 pub client: TokioIo<hyper::upgrade::Upgraded>,
46 pub server: TokioIo<hyper::upgrade::Upgraded>,
48}
/// HTTP client that opens a new connection per request.
///
/// It keeps two TLS connectors for the enabled backend: one that offers no ALPN
/// protocols and one that offers `h2` and `http/1.1`.
#[derive(Clone)]
pub struct DefaultClient {
    /// Connector used for requests that are not HTTP/2 (no ALPN offered).
    #[cfg(feature = "native-tls-client")]
    tls_connector_no_alpn: tokio_native_tls::TlsConnector,
    /// Connector used for requests that are not HTTP/2 (no ALPN offered).
    #[cfg(feature = "rustls-client")]
    tls_connector_no_alpn: tokio_rustls::TlsConnector,

    /// Connector used for HTTP/2 requests (ALPN offers `h2`, then `http/1.1`).
    #[cfg(feature = "native-tls-client")]
    tls_connector_alpn_h2: tokio_native_tls::TlsConnector,
    /// Connector used for HTTP/2 requests (ALPN offers `h2`, then `http/1.1`).
    #[cfg(feature = "rustls-client")]
    tls_connector_alpn_h2: tokio_rustls::TlsConnector,

    /// When `true`, `send_request` returns the upgraded connections of a
    /// `101 Switching Protocols` response to the caller; when `false` it relays
    /// bytes between them on a background task and returns `None`.
    pub with_upgrades: bool,
}
66impl DefaultClient {
67 #[cfg(feature = "native-tls-client")]
68 pub fn new() -> Self {
69 let tls_connector_no_alpn = native_tls::TlsConnector::builder().build().unwrap();
70 let tls_connector_alpn_h2 = native_tls::TlsConnector::builder()
71 .request_alpns(&["h2", "http/1.1"])
72 .build()
73 .unwrap();
74
75 Self {
76 tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
77 tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
78 with_upgrades: false,
79 }
80 }
81
82 #[cfg(feature = "rustls-client")]
83 pub fn new() -> Self {
84 use std::sync::Arc;
85
86 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
87 root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
88
89 let tls_connector_no_alpn = tokio_rustls::rustls::ClientConfig::builder()
90 .with_root_certificates(root_cert_store.clone())
91 .with_no_client_auth();
92 let mut tls_connector_alpn_h2 = tokio_rustls::rustls::ClientConfig::builder()
93 .with_root_certificates(root_cert_store.clone())
94 .with_no_client_auth();
95 tls_connector_alpn_h2.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
96
97 Self {
98 tls_connector_no_alpn: tokio_rustls::TlsConnector::from(Arc::new(
99 tls_connector_no_alpn,
100 )),
101 tls_connector_alpn_h2: tokio_rustls::TlsConnector::from(Arc::new(
102 tls_connector_alpn_h2,
103 )),
104 with_upgrades: false,
105 }
106 }
107
108 pub fn with_upgrades(mut self) -> Self {
111 self.with_upgrades = true;
112 self
113 }
114
115 #[cfg(feature = "native-tls-client")]
116 fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector {
117 match http_version {
118 Version::HTTP_2 => &self.tls_connector_alpn_h2,
119 _ => &self.tls_connector_no_alpn,
120 }
121 }
122
123 #[cfg(feature = "rustls-client")]
124 fn tls_connector(&self, http_version: Version) -> &tokio_rustls::TlsConnector {
125 match http_version {
126 Version::HTTP_2 => &self.tls_connector_alpn_h2,
127 _ => &self.tls_connector_no_alpn,
128 }
129 }
130
131 pub async fn send_request<B>(
135 &self,
136 req: Request<B>,
137 ) -> Result<
138 (
139 Response<Incoming>,
140 Option<JoinHandle<Result<Upgraded, hyper::Error>>>,
141 ),
142 Error,
143 >
144 where
145 B: Body + Unpin + Send + 'static,
146 B::Data: Send,
147 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
148 {
149 let mut send_request = self.connect(req.uri(), req.version()).await?;
150
151 let (req_parts, req_body) = req.into_parts();
152
153 let res = send_request
154 .send_request(Request::from_parts(req_parts.clone(), req_body))
155 .await?;
156
157 if res.status() == StatusCode::SWITCHING_PROTOCOLS {
158 let (res_parts, res_body) = res.into_parts();
159
160 let client_request = Request::from_parts(req_parts, Empty::<Bytes>::new());
161 let server_response = Response::from_parts(res_parts.clone(), Empty::<Bytes>::new());
162
163 let upgrade = if self.with_upgrades {
164 Some(tokio::task::spawn(async move {
165 let client = hyper::upgrade::on(client_request).await?;
166 let server = hyper::upgrade::on(server_response).await?;
167
168 Ok(Upgraded {
169 client: TokioIo::new(client),
170 server: TokioIo::new(server),
171 })
172 }))
173 } else {
174 tokio::task::spawn(async move {
175 let client = hyper::upgrade::on(client_request).await?;
176 let server = hyper::upgrade::on(server_response).await?;
177
178 let _ = tokio::io::copy_bidirectional(
179 &mut TokioIo::new(client),
180 &mut TokioIo::new(server),
181 )
182 .await;
183
184 Ok::<_, hyper::Error>(())
185 });
186 None
187 };
188
189 Ok((Response::from_parts(res_parts, res_body), upgrade))
190 } else {
191 Ok((res, None))
192 }
193 }
194
195 async fn connect<B>(&self, uri: &Uri, http_version: Version) -> Result<SendRequest<B>, Error>
196 where
197 B: Body + Unpin + Send + 'static,
198 B::Data: Send,
199 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
200 {
201 let host = uri.host().ok_or_else(|| Error::InvalidHost(uri.clone()))?;
202 let port =
203 uri.port_u16()
204 .unwrap_or(if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
205 443
206 } else {
207 80
208 });
209
210 let tcp = TcpStream::connect((host, port)).await?;
211 let _ = tcp.set_nodelay(true);
213
214 if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
215 #[cfg(feature = "native-tls-client")]
216 let tls = self
217 .tls_connector(http_version)
218 .connect(host, tcp)
219 .await
220 .map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
221 #[cfg(feature = "rustls-client")]
222 let tls = self
223 .tls_connector(http_version)
224 .connect(host.to_string().try_into().expect("Invalid host"), tcp)
225 .await
226 .map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
227
228 #[cfg(feature = "native-tls-client")]
229 let is_h2 = matches!(
230 tls.get_ref()
231 .negotiated_alpn()
232 .map(|a| a.map(|b| b == b"h2")),
233 Ok(Some(true))
234 );
235
236 #[cfg(feature = "rustls-client")]
237 let is_h2 = tls.get_ref().1.alpn_protocol() == Some(b"h2");
238
239 if is_h2 {
240 let (sender, conn) = client::conn::http2::Builder::new(TokioExecutor::new())
241 .handshake(TokioIo::new(tls))
242 .await
243 .map_err(|err| Error::ConnectError(uri.clone(), err))?;
244
245 tokio::spawn(conn);
246
247 Ok(SendRequest::Http2(sender))
248 } else {
249 let (sender, conn) = client::conn::http1::Builder::new()
250 .preserve_header_case(true)
251 .title_case_headers(true)
252 .handshake(TokioIo::new(tls))
253 .await
254 .map_err(|err| Error::ConnectError(uri.clone(), err))?;
255
256 tokio::spawn(conn.with_upgrades());
257
258 Ok(SendRequest::Http1(sender))
259 }
260 } else {
261 let (sender, conn) = client::conn::http1::Builder::new()
262 .preserve_header_case(true)
263 .title_case_headers(true)
264 .handshake(TokioIo::new(tcp))
265 .await
266 .map_err(|err| Error::ConnectError(uri.clone(), err))?;
267 tokio::spawn(conn.with_upgrades());
268 Ok(SendRequest::Http1(sender))
269 }
270 }
271}
272
273enum SendRequest<B> {
274 Http1(hyper::client::conn::http1::SendRequest<B>),
275 Http2(hyper::client::conn::http2::SendRequest<B>),
276}
277
278impl<B> SendRequest<B>
279where
280 B: Body + 'static,
281{
282 async fn send_request(
283 &mut self,
284 mut req: Request<B>,
285 ) -> Result<Response<Incoming>, hyper::Error> {
286 match self {
287 SendRequest::Http1(sender) => {
288 if req.version() == hyper::Version::HTTP_2 {
289 if let Some(authority) = req.uri().authority().cloned() {
290 req.headers_mut().insert(
291 header::HOST,
292 authority.as_str().parse().expect("Invalid authority"),
293 );
294 }
295 }
296 remove_authority(&mut req);
297 sender.send_request(req).await
298 }
299 SendRequest::Http2(sender) => {
300 if req.version() != hyper::Version::HTTP_2 {
301 req.headers_mut().remove(header::HOST);
302 }
303 sender.send_request(req).await
304 }
305 }
306 }
307}
308
309impl<B> SendRequest<B> {
310 #[allow(dead_code)]
311 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), hyper::Error>> {
313 match self {
314 SendRequest::Http1(sender) => sender.poll_ready(cx),
315 SendRequest::Http2(sender) => sender.poll_ready(cx),
316 }
317 }
318}
319
320fn remove_authority<B>(req: &mut Request<B>) {
321 let mut parts = req.uri().clone().into_parts();
322 parts.scheme = None;
323 parts.authority = None;
324 *req.uri_mut() = Uri::from_parts(parts).unwrap();
325}