use rustls::ClientConfig;
use super::HttpsConnector;
#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))]
use crate::config::ConfigBuilderExt;
#[cfg(feature = "tokio-runtime")]
use hyper::client::HttpConnector;
/// A builder for an [`HttpsConnector`].
///
/// The `State` type parameter records which configuration step comes next
/// (for example [`WantsTlsConfig`], then [`WantsSchemes`], then the protocol
/// selection states). Each step's methods are only available on the matching
/// `ConnectorBuilder<State>`, so the steps must be taken in order.
pub struct ConnectorBuilder<State>(State);
/// Builder state: a TLS [`ClientConfig`] must be supplied next.
pub struct WantsTlsConfig(());
impl ConnectorBuilder<WantsTlsConfig> {
    /// Starts a new builder; the first step is supplying a TLS configuration.
    pub fn new() -> Self {
        ConnectorBuilder(WantsTlsConfig(()))
    }

    /// Uses the given rustls [`ClientConfig`] for outgoing TLS connections.
    ///
    /// The builder's protocol-selection methods write `alpn_protocols` into
    /// this config later, so the caller must leave it empty.
    ///
    /// # Panics
    ///
    /// Panics if `config.alpn_protocols` is not empty.
    pub fn with_tls_config(self, config: ClientConfig) -> ConnectorBuilder<WantsSchemes> {
        if !config.alpn_protocols.is_empty() {
            panic!("ALPN protocols should not be pre-defined");
        }
        ConnectorBuilder(WantsSchemes { tls_config: config })
    }

    /// Uses a [`ClientConfig`] built with rustls' safe defaults, the root
    /// certificates supplied by `ConfigBuilderExt::with_native_roots`, and no
    /// client authentication.
    #[cfg(feature = "rustls-native-certs")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rustls-native-certs")))]
    pub fn with_native_roots(self) -> ConnectorBuilder<WantsSchemes> {
        let native_config = ClientConfig::builder()
            .with_safe_defaults()
            .with_native_roots()
            .with_no_client_auth();
        self.with_tls_config(native_config)
    }

    /// Uses a [`ClientConfig`] built with rustls' safe defaults, the root
    /// certificates supplied by `ConfigBuilderExt::with_webpki_roots`, and no
    /// client authentication.
    #[cfg(feature = "webpki-roots")]
    #[cfg_attr(docsrs, doc(cfg(feature = "webpki-roots")))]
    pub fn with_webpki_roots(self) -> ConnectorBuilder<WantsSchemes> {
        let webpki_config = ClientConfig::builder()
            .with_safe_defaults()
            .with_webpki_roots()
            .with_no_client_auth();
        self.with_tls_config(webpki_config)
    }
}
impl Default for ConnectorBuilder<WantsTlsConfig> {
fn default() -> Self {
Self::new()
}
}
/// Builder state: the URI scheme policy (HTTPS only, or HTTPS and HTTP) must
/// be chosen next.
pub struct WantsSchemes {
/// TLS configuration supplied in the previous step.
tls_config: ClientConfig,
}
impl ConnectorBuilder<WantsSchemes> {
pub fn https_only(self) -> ConnectorBuilder<WantsProtocols1> {
ConnectorBuilder(WantsProtocols1 {
tls_config: self.0.tls_config,
https_only: true,
override_server_name: None,
})
}
pub fn https_or_http(self) -> ConnectorBuilder<WantsProtocols1> {
ConnectorBuilder(WantsProtocols1 {
tls_config: self.0.tls_config,
https_only: false,
override_server_name: None,
})
}
}
/// Builder state: the HTTP protocol versions must be chosen next. A
/// server-name override may also be set at this stage.
pub struct WantsProtocols1 {
/// TLS configuration; the protocol-selection methods write `alpn_protocols` into it.
tls_config: ClientConfig,
/// Copied into the connector's `force_https` field.
https_only: bool,
/// Copied into the connector's `override_server_name`; `None` unless `with_server_name` was called.
override_server_name: Option<String>,
}
impl WantsProtocols1 {
fn wrap_connector<H>(self, conn: H) -> HttpsConnector<H> {
HttpsConnector {
force_https: self.https_only,
http: conn,
tls_config: std::sync::Arc::new(self.tls_config),
override_server_name: self.override_server_name,
}
}
#[cfg(feature = "tokio-runtime")]
fn build(self) -> HttpsConnector<HttpConnector> {
let mut http = HttpConnector::new();
http.enforce_http(false);
self.wrap_connector(http)
}
}
impl ConnectorBuilder<WantsProtocols1> {
    /// Enables HTTP/1.1.
    ///
    /// No ALPN protocols are set at this step. Calling `enable_http2` on the
    /// returned builder sets ALPN to `h2` followed by `http/1.1`.
    #[cfg(feature = "http1")]
    #[cfg_attr(docsrs, doc(cfg(feature = "http1")))]
    pub fn enable_http1(self) -> ConnectorBuilder<WantsProtocols2> {
        ConnectorBuilder(WantsProtocols2 { inner: self.0 })
    }

    /// Enables HTTP/2 only: ALPN is set to just `h2`.
    #[cfg(feature = "http2")]
    #[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
    pub fn enable_http2(mut self) -> ConnectorBuilder<WantsProtocols3> {
        self.0.tls_config.alpn_protocols = vec![b"h2".to_vec()];
        ConnectorBuilder(WantsProtocols3 {
            inner: self.0,
            enable_http1: false,
        })
    }

    /// Enables every HTTP version compiled into this crate.
    ///
    /// ALPN is set to `h2`, followed by `http/1.1` when the `http1` feature is
    /// enabled.
    #[cfg(feature = "http2")]
    #[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
    pub fn enable_all_versions(mut self) -> ConnectorBuilder<WantsProtocols3> {
        #[cfg(feature = "http1")]
        let alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
        #[cfg(not(feature = "http1"))]
        let alpn_protocols = vec![b"h2".to_vec()];
        self.0.tls_config.alpn_protocols = alpn_protocols;
        ConnectorBuilder(WantsProtocols3 {
            inner: self.0,
            enable_http1: cfg!(feature = "http1"),
        })
    }

    /// Sets a server name that the built connector stores as its
    /// `override_server_name`. It is presumably used for TLS in place of the
    /// request URI's host; see [`HttpsConnector`] for how it is applied.
    pub fn with_server_name(mut self, override_server_name: String) -> Self {
        self.0.override_server_name = Some(override_server_name);
        self
    }
}
/// Builder state: HTTP/1.1 is enabled; HTTP/2 may be added, or the connector
/// can be built as-is.
pub struct WantsProtocols2 {
/// Accumulated configuration from the earlier steps.
inner: WantsProtocols1,
}
impl ConnectorBuilder<WantsProtocols2> {
    /// Additionally enables HTTP/2: ALPN is set to `h2` followed by `http/1.1`.
    #[cfg(feature = "http2")]
    #[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
    pub fn enable_http2(mut self) -> ConnectorBuilder<WantsProtocols3> {
        self.0.inner.tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
        ConnectorBuilder(WantsProtocols3 {
            inner: self.0.inner,
            enable_http1: true,
        })
    }

    /// Builds an [`HttpsConnector`] on top of hyper's default [`HttpConnector`].
    #[cfg(feature = "tokio-runtime")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
    pub fn build(self) -> HttpsConnector<HttpConnector> {
        self.0.inner.build()
    }

    /// Builds an [`HttpsConnector`] that wraps the given inner connector `conn`.
    pub fn wrap_connector<H>(self, conn: H) -> HttpsConnector<H> {
        self.0.inner.wrap_connector(conn)
    }
}
/// Builder state: HTTP/2 is enabled (possibly alongside HTTP/1.1); the
/// connector can now be built.
#[cfg(feature = "http2")]
pub struct WantsProtocols3 {
/// Accumulated configuration, with ALPN already written into `tls_config`.
inner: WantsProtocols1,
// Records whether HTTP/1.1 was also enabled. ALPN is written directly into
// `inner.tls_config` by the methods producing this state, so the flag is not
// read back (hence `allow(dead_code)`).
#[allow(dead_code)]
enable_http1: bool,
}
#[cfg(feature = "http2")]
impl ConnectorBuilder<WantsProtocols3> {
    /// Builds an [`HttpsConnector`] on top of hyper's default [`HttpConnector`].
    #[cfg(feature = "tokio-runtime")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
    pub fn build(self) -> HttpsConnector<HttpConnector> {
        self.0.inner.build()
    }

    /// Builds an [`HttpsConnector`] that wraps the given inner connector `conn`.
    pub fn wrap_connector<H>(self, conn: H) -> HttpsConnector<H> {
        self.0.inner.wrap_connector(conn)
    }
}
#[cfg(test)]
mod tests {
    // Every test below calls `.build()`, which only exists when the
    // `tokio-runtime` feature is enabled, so each test is gated on it too.
    // Otherwise e.g. `--no-default-features --features http1` fails to compile.

    #[test]
    #[cfg(all(feature = "webpki-roots", feature = "http1", feature = "tokio-runtime"))]
    fn test_builder() {
        let _connector = super::ConnectorBuilder::new()
            .with_webpki_roots()
            .https_only()
            .enable_http1()
            .build();
    }

    #[test]
    #[cfg(all(feature = "http1", feature = "tokio-runtime"))]
    #[should_panic(expected = "ALPN protocols should not be pre-defined")]
    fn test_reject_predefined_alpn() {
        let roots = rustls::RootCertStore::empty();
        let mut config_with_alpn = rustls::ClientConfig::builder()
            .with_safe_defaults()
            .with_root_certificates(roots)
            .with_no_client_auth();
        config_with_alpn.alpn_protocols = vec![b"fancyprotocol".to_vec()];
        let _connector = super::ConnectorBuilder::new()
            .with_tls_config(config_with_alpn)
            .https_only()
            .enable_http1()
            .build();
    }

    #[test]
    #[cfg(all(feature = "http1", feature = "http2", feature = "tokio-runtime"))]
    fn test_alpn() {
        let roots = rustls::RootCertStore::empty();
        let tls_config = rustls::ClientConfig::builder()
            .with_safe_defaults()
            .with_root_certificates(roots)
            .with_no_client_auth();
        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config.clone())
            .https_only()
            .enable_http1()
            .build();
        assert!(connector
            .tls_config
            .alpn_protocols
            .is_empty());
        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config.clone())
            .https_only()
            .enable_http2()
            .build();
        assert_eq!(&connector.tls_config.alpn_protocols, &[b"h2".to_vec()]);
        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config.clone())
            .https_only()
            .enable_http1()
            .enable_http2()
            .build();
        assert_eq!(
            &connector.tls_config.alpn_protocols,
            &[b"h2".to_vec(), b"http/1.1".to_vec()]
        );
        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config)
            .https_only()
            .enable_all_versions()
            .build();
        assert_eq!(
            &connector.tls_config.alpn_protocols,
            &[b"h2".to_vec(), b"http/1.1".to_vec()]
        );
    }

    #[test]
    #[cfg(all(not(feature = "http1"), feature = "http2", feature = "tokio-runtime"))]
    fn test_alpn_http2() {
        let roots = rustls::RootCertStore::empty();
        let tls_config = rustls::ClientConfig::builder()
            .with_safe_defaults()
            .with_root_certificates(roots)
            .with_no_client_auth();
        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config.clone())
            .https_only()
            .enable_http2()
            .build();
        assert_eq!(&connector.tls_config.alpn_protocols, &[b"h2".to_vec()]);
        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config)
            .https_only()
            .enable_all_versions()
            .build();
        assert_eq!(&connector.tls_config.alpn_protocols, &[b"h2".to_vec()]);
    }
}