hyper_rustls/connector/
builder.rs1use std::sync::Arc;
2
3use hyper_util::client::legacy::connect::HttpConnector;
4#[cfg(any(
5 feature = "rustls-native-certs",
6 feature = "rustls-platform-verifier",
7 feature = "webpki-roots"
8))]
9use rustls::crypto::CryptoProvider;
10use rustls::ClientConfig;
11
12use super::{DefaultServerNameResolver, HttpsConnector, ResolveServerName};
13#[cfg(any(
14 feature = "rustls-native-certs",
15 feature = "webpki-roots",
16 feature = "rustls-platform-verifier"
17))]
18use crate::config::ConfigBuilderExt;
19use pki_types::ServerName;
20
/// A type-state builder for [`HttpsConnector`].
///
/// The `State` type parameter records which configuration step comes next.
/// Every step method consumes the builder and returns a builder in the next
/// state, so steps cannot be skipped or repeated:
///
/// 1. [`WantsTlsConfig`] — choose a rustls [`ClientConfig`] (start from
///    [`ConnectorBuilder::new`]).
/// 2. [`WantsSchemes`] — choose `https_only` or `https_or_http`.
/// 3. [`WantsProtocols1`] — enable HTTP versions and optionally install a
///    custom server-name resolver.
/// 4. `WantsProtocols2` / `WantsProtocols3` — call `build` or
///    `wrap_connector` to obtain the connector.
pub struct ConnectorBuilder<State>(State);

/// Initial builder state: no TLS configuration has been chosen yet.
///
/// The private unit field prevents construction outside this module; use
/// [`ConnectorBuilder::new`] instead.
pub struct WantsTlsConfig(());

46impl ConnectorBuilder<WantsTlsConfig> {
47 pub fn new() -> Self {
49 Self(WantsTlsConfig(()))
50 }
51
52 pub fn with_tls_config(self, config: ClientConfig) -> ConnectorBuilder<WantsSchemes> {
61 assert!(
62 config.alpn_protocols.is_empty(),
63 "ALPN protocols should not be pre-defined"
64 );
65 ConnectorBuilder(WantsSchemes { tls_config: config })
66 }
67
68 #[cfg(all(
73 any(feature = "ring", feature = "aws-lc-rs"),
74 feature = "rustls-platform-verifier"
75 ))]
76 pub fn with_platform_verifier(self) -> ConnectorBuilder<WantsSchemes> {
77 self.try_with_platform_verifier()
78 .expect("failure to initialize platform verifier")
79 }
80
81 #[cfg(all(
86 any(feature = "ring", feature = "aws-lc-rs"),
87 feature = "rustls-platform-verifier"
88 ))]
89 pub fn try_with_platform_verifier(
90 self,
91 ) -> Result<ConnectorBuilder<WantsSchemes>, rustls::Error> {
92 Ok(self.with_tls_config(
93 ClientConfig::builder()
94 .try_with_platform_verifier()?
95 .with_no_client_auth(),
96 ))
97 }
98
99 #[cfg(feature = "rustls-platform-verifier")]
103 pub fn with_provider_and_platform_verifier(
104 self,
105 provider: impl Into<Arc<CryptoProvider>>,
106 ) -> std::io::Result<ConnectorBuilder<WantsSchemes>> {
107 Ok(self.with_tls_config(
108 ClientConfig::builder_with_provider(provider.into())
109 .with_safe_default_protocol_versions()
110 .and_then(|builder| builder.try_with_platform_verifier())
111 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?
112 .with_no_client_auth(),
113 ))
114 }
115
116 #[cfg(all(
121 any(feature = "ring", feature = "aws-lc-rs"),
122 feature = "rustls-native-certs"
123 ))]
124 pub fn with_native_roots(self) -> std::io::Result<ConnectorBuilder<WantsSchemes>> {
125 Ok(self.with_tls_config(
126 ClientConfig::builder()
127 .with_native_roots()?
128 .with_no_client_auth(),
129 ))
130 }
131
132 #[cfg(feature = "rustls-native-certs")]
136 pub fn with_provider_and_native_roots(
137 self,
138 provider: impl Into<Arc<CryptoProvider>>,
139 ) -> std::io::Result<ConnectorBuilder<WantsSchemes>> {
140 Ok(self.with_tls_config(
141 ClientConfig::builder_with_provider(provider.into())
142 .with_safe_default_protocol_versions()
143 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?
144 .with_native_roots()?
145 .with_no_client_auth(),
146 ))
147 }
148
149 #[cfg(all(any(feature = "ring", feature = "aws-lc-rs"), feature = "webpki-roots"))]
154 pub fn with_webpki_roots(self) -> ConnectorBuilder<WantsSchemes> {
155 self.with_tls_config(
156 ClientConfig::builder()
157 .with_webpki_roots()
158 .with_no_client_auth(),
159 )
160 }
161
162 #[cfg(feature = "webpki-roots")]
167 pub fn with_provider_and_webpki_roots(
168 self,
169 provider: impl Into<Arc<CryptoProvider>>,
170 ) -> Result<ConnectorBuilder<WantsSchemes>, rustls::Error> {
171 Ok(self.with_tls_config(
172 ClientConfig::builder_with_provider(provider.into())
173 .with_safe_default_protocol_versions()?
174 .with_webpki_roots()
175 .with_no_client_auth(),
176 ))
177 }
178}
179
180impl Default for ConnectorBuilder<WantsTlsConfig> {
181 fn default() -> Self {
182 Self::new()
183 }
184}
185
/// Builder state after the TLS configuration has been chosen; the next step
/// decides which URI schemes the connector accepts.
pub struct WantsSchemes {
    /// The rustls client configuration. Its ALPN list is still empty here,
    /// because `with_tls_config` asserts that.
    tls_config: ClientConfig,
}

192impl ConnectorBuilder<WantsSchemes> {
193 pub fn https_only(self) -> ConnectorBuilder<WantsProtocols1> {
197 ConnectorBuilder(WantsProtocols1 {
198 tls_config: self.0.tls_config,
199 https_only: true,
200 server_name_resolver: None,
201 })
202 }
203
204 pub fn https_or_http(self) -> ConnectorBuilder<WantsProtocols1> {
209 ConnectorBuilder(WantsProtocols1 {
210 tls_config: self.0.tls_config,
211 https_only: false,
212 server_name_resolver: None,
213 })
214 }
215}
216
/// Builder state after the scheme policy has been chosen; the next step
/// enables HTTP versions. It also carries everything needed to assemble the
/// final [`HttpsConnector`].
pub struct WantsProtocols1 {
    /// rustls client configuration; the HTTP-version steps may write ALPN
    /// protocols into it.
    tls_config: ClientConfig,
    /// Copied into the connector's `force_https` flag.
    https_only: bool,
    /// Custom server-name resolver; `None` means [`DefaultServerNameResolver`]
    /// is used when the connector is assembled.
    server_name_resolver: Option<Arc<dyn ResolveServerName + Sync + Send>>,
}

227impl WantsProtocols1 {
228 fn wrap_connector<H>(self, conn: H) -> HttpsConnector<H> {
229 HttpsConnector {
230 force_https: self.https_only,
231 http: conn,
232 tls_config: std::sync::Arc::new(self.tls_config),
233 server_name_resolver: self
234 .server_name_resolver
235 .unwrap_or_else(|| Arc::new(DefaultServerNameResolver::default())),
236 }
237 }
238
239 fn build(self) -> HttpsConnector<HttpConnector> {
240 let mut http = HttpConnector::new();
241 http.enforce_http(false);
243 self.wrap_connector(http)
244 }
245}
246
247impl ConnectorBuilder<WantsProtocols1> {
248 #[cfg(feature = "http1")]
252 pub fn enable_http1(self) -> ConnectorBuilder<WantsProtocols2> {
253 ConnectorBuilder(WantsProtocols2 { inner: self.0 })
254 }
255
256 #[cfg(feature = "http2")]
260 pub fn enable_http2(mut self) -> ConnectorBuilder<WantsProtocols3> {
261 self.0.tls_config.alpn_protocols = vec![b"h2".to_vec()];
262 ConnectorBuilder(WantsProtocols3 {
263 inner: self.0,
264 enable_http1: false,
265 })
266 }
267
268 #[cfg(feature = "http2")]
273 pub fn enable_all_versions(mut self) -> ConnectorBuilder<WantsProtocols3> {
274 #[cfg(feature = "http1")]
275 let alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
276 #[cfg(not(feature = "http1"))]
277 let alpn_protocols = vec![b"h2".to_vec()];
278
279 self.0.tls_config.alpn_protocols = alpn_protocols;
280 ConnectorBuilder(WantsProtocols3 {
281 inner: self.0,
282 enable_http1: cfg!(feature = "http1"),
283 })
284 }
285
286 pub fn with_server_name_resolver(
295 mut self,
296 resolver: impl ResolveServerName + 'static + Sync + Send,
297 ) -> Self {
298 self.0.server_name_resolver = Some(Arc::new(resolver));
299 self
300 }
301
302 #[deprecated(
312 since = "0.27.1",
313 note = "use Self::with_server_name_resolver with FixedServerNameResolver instead"
314 )]
315 pub fn with_server_name(self, mut override_server_name: String) -> Self {
316 if let Some(trimmed) = override_server_name
318 .strip_prefix('[')
319 .and_then(|s| s.strip_suffix(']'))
320 {
321 override_server_name = trimmed.to_string();
322 }
323
324 self.with_server_name_resolver(move |_: &_| {
325 ServerName::try_from(override_server_name.clone())
326 })
327 }
328}
329
/// Builder state after HTTP/1.1 has been enabled; no ALPN protocol is
/// advertised yet.
///
/// From here HTTP/2 may additionally be enabled, or the connector built.
pub struct WantsProtocols2 {
    inner: WantsProtocols1,
}

340impl ConnectorBuilder<WantsProtocols2> {
341 #[cfg(feature = "http2")]
345 pub fn enable_http2(mut self) -> ConnectorBuilder<WantsProtocols3> {
346 self.0.inner.tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
347 ConnectorBuilder(WantsProtocols3 {
348 inner: self.0.inner,
349 enable_http1: true,
350 })
351 }
352
353 pub fn build(self) -> HttpsConnector<HttpConnector> {
355 self.0.inner.build()
356 }
357
358 pub fn wrap_connector<H>(self, conn: H) -> HttpsConnector<H> {
360 self.0.inner.wrap_connector(conn)
365 }
366}
367
/// Builder state after HTTP/2 has been enabled, with or without HTTP/1.1.
///
/// The ALPN protocol list has already been written into the TLS config, so
/// the only remaining steps are `build` and `wrap_connector`.
#[cfg(feature = "http2")]
pub struct WantsProtocols3 {
    inner: WantsProtocols1,
    // Whether HTTP/1.1 was also enabled. Nothing in this builder reads it back:
    // the ALPN list in `inner.tls_config` already encodes the choice, hence
    // the `allow(dead_code)`.
    #[allow(dead_code)]
    enable_http1: bool,
}

#[cfg(feature = "http2")]
impl ConnectorBuilder<WantsProtocols3> {
    /// Builds an [`HttpsConnector`] over a default [`HttpConnector`], using the
    /// ALPN list chosen when HTTP/2 was enabled.
    pub fn build(self) -> HttpsConnector<HttpConnector> {
        let ConnectorBuilder(WantsProtocols3 { inner, .. }) = self;
        inner.build()
    }

    /// Wraps an arbitrary lower-level connector in an [`HttpsConnector`], using
    /// the ALPN list chosen when HTTP/2 was enabled.
    pub fn wrap_connector<H>(self, conn: H) -> HttpsConnector<H> {
        let ConnectorBuilder(WantsProtocols3 { inner, .. }) = self;
        inner.wrap_connector(conn)
    }
}

#[cfg(test)]
mod tests {
    // Every test below calls `rustls::ClientConfig::builder()` (directly or via
    // `with_webpki_roots`), which needs a crypto provider. The tests are
    // therefore gated on a built-in provider feature.

    /// Smoke test: webpki roots + HTTPS-only + HTTP/1.1 builds a connector.
    #[test]
    #[cfg(all(
        feature = "webpki-roots",
        feature = "http1",
        any(feature = "ring", feature = "aws-lc-rs")
    ))]
    fn test_builder() {
        ensure_global_state();
        let _connector = super::ConnectorBuilder::new()
            .with_webpki_roots()
            .https_only()
            .enable_http1()
            .build();
    }

    /// A caller-supplied config that already sets ALPN must be rejected.
    #[test]
    #[cfg(all(feature = "http1", any(feature = "ring", feature = "aws-lc-rs")))]
    #[should_panic(expected = "ALPN protocols should not be pre-defined")]
    fn test_reject_predefined_alpn() {
        ensure_global_state();
        let roots = rustls::RootCertStore::empty();
        let mut config_with_alpn = rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();
        config_with_alpn.alpn_protocols = vec![b"fancyprotocol".to_vec()];
        let _connector = super::ConnectorBuilder::new()
            .with_tls_config(config_with_alpn)
            .https_only()
            .enable_http1()
            .build();
    }

    /// Each HTTP-version path writes the expected ALPN list (both versions on).
    #[test]
    #[cfg(all(
        feature = "http1",
        feature = "http2",
        any(feature = "ring", feature = "aws-lc-rs")
    ))]
    fn test_alpn() {
        ensure_global_state();
        let roots = rustls::RootCertStore::empty();
        let tls_config = rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();

        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config.clone())
            .https_only()
            .enable_http1()
            .build();
        assert!(connector.tls_config.alpn_protocols.is_empty());

        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config.clone())
            .https_only()
            .enable_http2()
            .build();
        assert_eq!(&connector.tls_config.alpn_protocols, &[b"h2".to_vec()]);

        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config.clone())
            .https_only()
            .enable_http1()
            .enable_http2()
            .build();
        assert_eq!(
            &connector.tls_config.alpn_protocols,
            &[b"h2".to_vec(), b"http/1.1".to_vec()]
        );

        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config)
            .https_only()
            .enable_all_versions()
            .build();
        assert_eq!(
            &connector.tls_config.alpn_protocols,
            &[b"h2".to_vec(), b"http/1.1".to_vec()]
        );
    }

    /// Without the `http1` feature, both HTTP/2 paths advertise only `h2`.
    #[test]
    #[cfg(all(
        not(feature = "http1"),
        feature = "http2",
        any(feature = "ring", feature = "aws-lc-rs")
    ))]
    fn test_alpn_http2() {
        // Previously missing: without an installed provider,
        // `ClientConfig::builder()` has nothing to build with.
        ensure_global_state();
        let roots = rustls::RootCertStore::empty();
        // `with_safe_defaults()` no longer exists in rustls 0.22+; `builder()`
        // already applies safe defaults, matching the other tests.
        let tls_config = rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();

        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config.clone())
            .https_only()
            .enable_http2()
            .build();
        assert_eq!(&connector.tls_config.alpn_protocols, &[b"h2".to_vec()]);

        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config)
            .https_only()
            .enable_all_versions()
            .build();
        assert_eq!(&connector.tls_config.alpn_protocols, &[b"h2".to_vec()]);
    }

    /// Installs a process-wide default rustls crypto provider so that
    /// `rustls::ClientConfig::builder()` works in tests.
    ///
    /// Installation errors are ignored on purpose: another test may already
    /// have installed a provider, or both `ring` and `aws-lc-rs` are enabled
    /// and the second install is refused.
    #[allow(dead_code)] // unused under feature sets that compile out every test
    fn ensure_global_state() {
        #[cfg(feature = "ring")]
        let _ = rustls::crypto::ring::default_provider().install_default();
        #[cfg(feature = "aws-lc-rs")]
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    }
}