1use bytes::Bytes;
2use http_body_util::Empty;
3use hyper::{
4 Request, Response, StatusCode, Uri, Version,
5 body::{Body, Incoming},
6 client, header,
7};
8use hyper_util::rt::{TokioExecutor, TokioIo};
9use std::task::{Context, Poll};
10use tokio::{net::TcpStream, task::JoinHandle};
11
12#[cfg(all(feature = "native-tls-client", feature = "rustls-client"))]
13compile_error!(
14 "feature \"native-tls-client\" and feature \"rustls-client\" cannot be enabled at the same time"
15);
16
17#[cfg(all(not(feature = "native-tls-client"), not(feature = "rustls-client")))]
18compile_error!("feature \"native-tls-client\" or feature \"rustls-client\" must be enabled");
19
20#[derive(thiserror::Error, Debug)]
21pub enum Error {
22 #[error("{0} doesn't have an valid host")]
23 InvalidHost(Uri),
24 #[error(transparent)]
25 IoError(#[from] std::io::Error),
26 #[error(transparent)]
27 HyperError(#[from] hyper::Error),
28 #[error("Failed to connect to {0}, {1}")]
29 ConnectError(Uri, hyper::Error),
30
31 #[cfg(feature = "native-tls-client")]
32 #[error("Failed to connect with TLS to {0}, {1}")]
33 TlsConnectError(Uri, native_tls::Error),
34 #[cfg(feature = "native-tls-client")]
35 #[error(transparent)]
36 NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
37
38 #[cfg(feature = "rustls-client")]
39 #[error("Failed to connect with TLS to {0}, {1}")]
40 TlsConnectError(Uri, std::io::Error),
41}
42
/// Both halves of a connection that switched protocols (`101 Switching Protocols`).
///
/// Handed back by `DefaultClient::send_request` when `with_upgrades` is enabled.
pub struct Upgraded {
    /// Upgraded IO obtained from the original (client-side) request.
    pub client: TokioIo<hyper::upgrade::Upgraded>,
    /// Upgraded IO obtained from the upstream server's `101` response.
    pub server: TokioIo<hyper::upgrade::Upgraded>,
}
/// HTTP(S) client that speaks HTTP/1.1, or HTTP/2 when it is negotiated via TLS ALPN.
///
/// Exactly one of the `native-tls-client` / `rustls-client` features selects the TLS backend.
#[derive(Clone)]
pub struct DefaultClient {
    /// native-tls connector that offers no ALPN protocols; used for non-HTTP/2 requests.
    #[cfg(feature = "native-tls-client")]
    tls_connector_no_alpn: tokio_native_tls::TlsConnector,
    /// native-tls connector that offers `h2` and `http/1.1` via ALPN; used for HTTP/2 requests.
    #[cfg(feature = "native-tls-client")]
    tls_connector_alpn_h2: tokio_native_tls::TlsConnector,

    /// rustls connector that offers no ALPN protocols; used for non-HTTP/2 requests.
    #[cfg(feature = "rustls-client")]
    tls_connector_no_alpn: tokio_rustls::TlsConnector,
    /// rustls connector that offers `h2` and `http/1.1` via ALPN; used for HTTP/2 requests.
    #[cfg(feature = "rustls-client")]
    tls_connector_alpn_h2: tokio_rustls::TlsConnector,

    /// When `true`, `send_request` returns a handle yielding both upgraded halves of a
    /// `101 Switching Protocols` exchange; when `false`, it relays bytes between them itself.
    pub with_upgrades: bool,
}
67impl DefaultClient {
68 #[cfg(feature = "native-tls-client")]
69 pub fn new() -> Self {
70 let tls_connector_no_alpn = native_tls::TlsConnector::builder().build().unwrap();
71 let tls_connector_alpn_h2 = native_tls::TlsConnector::builder()
72 .request_alpns(&["h2", "http/1.1"])
73 .build()
74 .unwrap();
75
76 Self {
77 tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
78 tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
79 with_upgrades: false,
80 }
81 }
82
83 #[cfg(feature = "rustls-client")]
84 pub fn new() -> Self {
85 use std::sync::Arc;
86
87 let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
88 root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
89
90 let tls_connector_no_alpn = tokio_rustls::rustls::ClientConfig::builder()
91 .with_root_certificates(root_cert_store.clone())
92 .with_no_client_auth();
93 let mut tls_connector_alpn_h2 = tokio_rustls::rustls::ClientConfig::builder()
94 .with_root_certificates(root_cert_store.clone())
95 .with_no_client_auth();
96 tls_connector_alpn_h2.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
97
98 Self {
99 tls_connector_no_alpn: tokio_rustls::TlsConnector::from(Arc::new(
100 tls_connector_no_alpn,
101 )),
102 tls_connector_alpn_h2: tokio_rustls::TlsConnector::from(Arc::new(
103 tls_connector_alpn_h2,
104 )),
105 with_upgrades: false,
106 }
107 }
108
109 pub fn with_upgrades(mut self) -> Self {
112 self.with_upgrades = true;
113 self
114 }
115
116 #[cfg(feature = "native-tls-client")]
117 fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector {
118 match http_version {
119 Version::HTTP_2 => &self.tls_connector_alpn_h2,
120 _ => &self.tls_connector_no_alpn,
121 }
122 }
123
124 #[cfg(feature = "rustls-client")]
125 fn tls_connector(&self, http_version: Version) -> &tokio_rustls::TlsConnector {
126 match http_version {
127 Version::HTTP_2 => &self.tls_connector_alpn_h2,
128 _ => &self.tls_connector_no_alpn,
129 }
130 }
131
132 pub async fn send_request<B>(
136 &self,
137 req: Request<B>,
138 ) -> Result<
139 (
140 Response<Incoming>,
141 Option<JoinHandle<Result<Upgraded, hyper::Error>>>,
142 ),
143 Error,
144 >
145 where
146 B: Body + Unpin + Send + 'static,
147 B::Data: Send,
148 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
149 {
150 let mut send_request = self.connect(req.uri(), req.version()).await?;
151
152 let (req_parts, req_body) = req.into_parts();
153
154 let res = send_request
155 .send_request(Request::from_parts(req_parts.clone(), req_body))
156 .await?;
157
158 if res.status() == StatusCode::SWITCHING_PROTOCOLS {
159 let (res_parts, res_body) = res.into_parts();
160
161 let client_request = Request::from_parts(req_parts, Empty::<Bytes>::new());
162 let server_response = Response::from_parts(res_parts.clone(), Empty::<Bytes>::new());
163
164 let upgrade = if self.with_upgrades {
165 Some(tokio::task::spawn(async move {
166 let client = hyper::upgrade::on(client_request).await?;
167 let server = hyper::upgrade::on(server_response).await?;
168
169 Ok(Upgraded {
170 client: TokioIo::new(client),
171 server: TokioIo::new(server),
172 })
173 }))
174 } else {
175 tokio::task::spawn(async move {
176 let client = hyper::upgrade::on(client_request).await?;
177 let server = hyper::upgrade::on(server_response).await?;
178
179 let _ = tokio::io::copy_bidirectional(
180 &mut TokioIo::new(client),
181 &mut TokioIo::new(server),
182 )
183 .await;
184
185 Ok::<_, hyper::Error>(())
186 });
187 None
188 };
189
190 Ok((Response::from_parts(res_parts, res_body), upgrade))
191 } else {
192 Ok((res, None))
193 }
194 }
195
196 async fn connect<B>(&self, uri: &Uri, http_version: Version) -> Result<SendRequest<B>, Error>
197 where
198 B: Body + Unpin + Send + 'static,
199 B::Data: Send,
200 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
201 {
202 let host = uri.host().ok_or_else(|| Error::InvalidHost(uri.clone()))?;
203 let port =
204 uri.port_u16()
205 .unwrap_or(if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
206 443
207 } else {
208 80
209 });
210
211 let tcp = TcpStream::connect((host, port)).await?;
212 let _ = tcp.set_nodelay(true);
214
215 if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
216 #[cfg(feature = "native-tls-client")]
217 let tls = self
218 .tls_connector(http_version)
219 .connect(host, tcp)
220 .await
221 .map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
222 #[cfg(feature = "rustls-client")]
223 let tls = self
224 .tls_connector(http_version)
225 .connect(host.to_string().try_into().expect("Invalid host"), tcp)
226 .await
227 .map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
228
229 #[cfg(feature = "native-tls-client")]
230 let is_h2 = matches!(
231 tls.get_ref()
232 .negotiated_alpn()
233 .map(|a| a.map(|b| b == b"h2")),
234 Ok(Some(true))
235 );
236
237 #[cfg(feature = "rustls-client")]
238 let is_h2 = tls.get_ref().1.alpn_protocol() == Some(b"h2");
239
240 if is_h2 {
241 let (sender, conn) = client::conn::http2::Builder::new(TokioExecutor::new())
242 .handshake(TokioIo::new(tls))
243 .await
244 .map_err(|err| Error::ConnectError(uri.clone(), err))?;
245
246 tokio::spawn(conn);
247
248 Ok(SendRequest::Http2(sender))
249 } else {
250 let (sender, conn) = client::conn::http1::Builder::new()
251 .preserve_header_case(true)
252 .title_case_headers(true)
253 .handshake(TokioIo::new(tls))
254 .await
255 .map_err(|err| Error::ConnectError(uri.clone(), err))?;
256
257 tokio::spawn(conn.with_upgrades());
258
259 Ok(SendRequest::Http1(sender))
260 }
261 } else {
262 let (sender, conn) = client::conn::http1::Builder::new()
263 .preserve_header_case(true)
264 .title_case_headers(true)
265 .handshake(TokioIo::new(tcp))
266 .await
267 .map_err(|err| Error::ConnectError(uri.clone(), err))?;
268 tokio::spawn(conn.with_upgrades());
269 Ok(SendRequest::Http1(sender))
270 }
271 }
272}
273
/// Request sender for one upstream connection, tagged by the protocol chosen in `connect`.
enum SendRequest<B> {
    /// HTTP/1.1 connection.
    Http1(hyper::client::conn::http1::SendRequest<B>),
    /// HTTP/2 connection, used when TLS ALPN negotiated `h2`.
    Http2(hyper::client::conn::http2::SendRequest<B>),
}
278
279impl<B> SendRequest<B>
280where
281 B: Body + 'static,
282{
283 async fn send_request(
284 &mut self,
285 mut req: Request<B>,
286 ) -> Result<Response<Incoming>, hyper::Error> {
287 match self {
288 SendRequest::Http1(sender) => {
289 if req.version() == hyper::Version::HTTP_2 {
290 if let Some(authority) = req.uri().authority().cloned() {
291 req.headers_mut().insert(
292 header::HOST,
293 authority.as_str().parse().expect("Invalid authority"),
294 );
295 }
296 }
297 remove_authority(&mut req);
298 sender.send_request(req).await
299 }
300 SendRequest::Http2(sender) => {
301 if req.version() != hyper::Version::HTTP_2 {
302 req.headers_mut().remove(header::HOST);
303 }
304 sender.send_request(req).await
305 }
306 }
307 }
308}
309
310impl<B> SendRequest<B> {
311 #[allow(dead_code)]
312 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), hyper::Error>> {
314 match self {
315 SendRequest::Http1(sender) => sender.poll_ready(cx),
316 SendRequest::Http2(sender) => sender.poll_ready(cx),
317 }
318 }
319}
320
321fn remove_authority<B>(req: &mut Request<B>) {
322 let mut parts = req.uri().clone().into_parts();
323 parts.scheme = None;
324 parts.authority = None;
325 *req.uri_mut() = Uri::from_parts(parts).unwrap();
326}