use crate::error::{Error, ErrorKind};
use std::{
fmt,
fmt::{Debug, Formatter},
net::IpAddr,
sync::Arc,
};
#[cfg(feature = "enable-native-tls")]
use std::convert::{TryFrom, TryInto};
#[cfg(feature = "enable-native-tls")]
use tokio_native_tls::native_tls::{
TlsConnector as NativeTlsConnector,
TlsConnectorBuilder as NativeTlsConnectorBuilder,
};
#[cfg(feature = "enable-native-tls")]
use tokio_native_tls::TlsConnector as TokioNativeTlsConnector;
#[cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))]
use tokio_rustls::rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
#[cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))]
use tokio_rustls::TlsConnector as RustlsConnector;
/// A caller-provided strategy for turning an IP address into a hostname.
///
/// Implementations are shared as `Arc<dyn HostMapping>` through
/// [`TlsHostMapping::Custom`], which is why the trait requires `Send`, `Sync`
/// and `Debug`.
#[cfg_attr(docsrs, doc(cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))))]
pub trait HostMapping: Debug + Send + Sync {
    /// Returns the hostname to use for `ip`, or `None` when this mapping has
    /// no hostname for that address.
    ///
    /// `default_host` is a fallback hostname supplied by the caller; an
    /// implementation is free to return it, transform it, or ignore it.
    // NOTE(review): presumably the returned name is used as the TLS server
    // name for the connection - the call site is outside this file; confirm.
    fn map(&self, ip: &IpAddr, default_host: &str) -> Option<String>;
}
/// The policy used to pick a hostname for an IP address.
#[cfg_attr(docsrs, doc(cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))))]
#[derive(Debug, Clone)]
pub enum TlsHostMapping {
    /// Perform no mapping; no hostname is produced for any address.
    None,
    /// Use the caller's default host for every address.
    DefaultHost,
    /// Delegate the decision to a shared, user-supplied [`HostMapping`].
    Custom(Arc<dyn HostMapping>),
}
impl TlsHostMapping {
pub(crate) fn map(&self, value: &IpAddr, default_host: &str) -> Option<String> {
match self {
TlsHostMapping::None => None,
TlsHostMapping::DefaultHost => Some(default_host.to_owned()),
TlsHostMapping::Custom(ref inner) => inner.map(value, default_host),
}
}
}
impl PartialEq for TlsHostMapping {
fn eq(&self, other: &Self) -> bool {
match self {
TlsHostMapping::None => matches!(other, TlsHostMapping::None),
TlsHostMapping::DefaultHost => matches!(other, TlsHostMapping::DefaultHost),
TlsHostMapping::Custom(_) => matches!(other, TlsHostMapping::Custom(_)),
}
}
}
impl Eq for TlsHostMapping {}
/// TLS settings: the connector to use plus the hostname-mapping policy.
#[cfg_attr(docsrs, doc(cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))))]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TlsConfig {
    /// The TLS connector implementation.
    pub connector: TlsConnector,
    /// The policy for mapping IP addresses to hostnames.
    pub hostnames: TlsHostMapping,
}
/// Builds a [`TlsConfig`] from anything convertible into a [`TlsConnector`].
///
/// The resulting config has `hostnames` set to [`TlsHostMapping::None`].
impl<C> From<C> for TlsConfig
where
    C: Into<TlsConnector>,
{
    fn from(connector: C) -> Self {
        let connector = connector.into();
        Self {
            connector,
            hostnames: TlsHostMapping::None,
        }
    }
}
/// A TLS connector backed by one of the compiled-in TLS implementations.
///
/// Each variant only exists when the corresponding cargo feature is enabled.
#[cfg_attr(docsrs, doc(cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))))]
#[derive(Clone)]
pub enum TlsConnector {
    /// A `tokio-native-tls` connector.
    #[cfg(feature = "enable-native-tls")]
    #[cfg_attr(docsrs, doc(cfg(feature = "enable-native-tls")))]
    Native(TokioNativeTlsConnector),
    /// A `tokio-rustls` connector.
    #[cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))))]
    Rustls(RustlsConnector),
}
/// Every connector compares equal to every other connector.
///
/// Neither the variant nor the wrapped connector is inspected, so this
/// comparison never distinguishes two `TlsConnector` values.
// NOTE(review): this makes the derived `PartialEq` on `TlsConfig` ignore the
// `connector` field entirely - presumably intentional; confirm with callers.
impl PartialEq for TlsConnector {
    fn eq(&self, _other: &Self) -> bool {
        true
    }
}

impl Eq for TlsConnector {}
/// Formats the connector as `TlsConnector { kind: "<variant>" }`.
///
/// Only the variant name is printed; the wrapped connector is not formatted.
impl Debug for TlsConnector {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Pick the label first so the formatting call below stays flat.
        let kind: &str = match self {
            #[cfg(feature = "enable-native-tls")]
            TlsConnector::Native(_) => "Native",
            #[cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))]
            TlsConnector::Rustls(_) => "Rustls",
        };
        f.debug_struct("TlsConnector").field("kind", &kind).finish()
    }
}
#[cfg_attr(docsrs, doc(cfg(any(feature = "enable-native-tls", feature = "enable-rustls"))))]
impl TlsConnector {
    /// Creates a `native-tls` connector from an unmodified builder.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the
    /// `TryFrom<NativeTlsConnectorBuilder>` conversion when the builder fails
    /// to build.
    #[cfg(feature = "enable-native-tls")]
    #[cfg_attr(docsrs, doc(cfg(feature = "enable-native-tls")))]
    pub fn default_native_tls() -> Result<Self, Error> {
        Self::try_from(NativeTlsConnector::builder())
    }

    /// Creates a `rustls` connector that trusts the certificates returned by
    /// `rustls_native_certs::load_native_certs` and uses no client
    /// authentication.
    ///
    /// # Errors
    ///
    /// * Returns an `ErrorKind::Tls` error if loading reported any error,
    ///   even when some certificates were loaded. Only the last reported
    ///   error is included, in its `Debug` form.
    /// * Propagates the error from `RootCertStore::add` if a loaded
    ///   certificate is rejected.
    #[cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))))]
    pub fn default_rustls() -> Result<Self, Error> {
        let mut native = rustls_native_certs::load_native_certs();
        // `pop` yields `Some` exactly when at least one error was reported.
        if let Some(failure) = native.errors.pop() {
            return Err(Error::new(ErrorKind::Tls, format!("{:?}", failure)));
        }

        let mut roots = RootCertStore::empty();
        for cert in native.certs {
            roots.add(cert)?;
        }

        let config = RustlsClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();
        Ok(config.into())
    }
}
/// Builds the `native-tls` connector and wraps it in
/// [`TlsConnector::Native`].
///
/// A build failure is reported as an `ErrorKind::Tls` error whose details are
/// the `Debug` rendering of the underlying error.
#[cfg(feature = "enable-native-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "enable-native-tls")))]
impl TryFrom<NativeTlsConnectorBuilder> for TlsConnector {
    type Error = Error;

    fn try_from(builder: NativeTlsConnectorBuilder) -> Result<Self, Self::Error> {
        match builder.build() {
            Ok(native) => Ok(TlsConnector::Native(TokioNativeTlsConnector::from(native))),
            Err(e) => Err(Error::new(ErrorKind::Tls, format!("{:?}", e))),
        }
    }
}
/// Converts a `native-tls` connector into its tokio counterpart and wraps it
/// in [`TlsConnector::Native`].
#[cfg(feature = "enable-native-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "enable-native-tls")))]
impl From<NativeTlsConnector> for TlsConnector {
    fn from(connector: NativeTlsConnector) -> Self {
        let tokio_connector = TokioNativeTlsConnector::from(connector);
        Self::Native(tokio_connector)
    }
}
/// Wraps an existing `tokio-native-tls` connector in
/// [`TlsConnector::Native`] without modifying it.
#[cfg(feature = "enable-native-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "enable-native-tls")))]
impl From<TokioNativeTlsConnector> for TlsConnector {
    fn from(connector: TokioNativeTlsConnector) -> Self {
        Self::Native(connector)
    }
}
/// Moves a `rustls` client config into an `Arc`, builds a `tokio-rustls`
/// connector from it, and wraps that in [`TlsConnector::Rustls`].
#[cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))))]
impl From<RustlsClientConfig> for TlsConnector {
    fn from(config: RustlsClientConfig) -> Self {
        let shared = Arc::new(config);
        Self::Rustls(RustlsConnector::from(shared))
    }
}
/// Wraps an existing `tokio-rustls` connector in [`TlsConnector::Rustls`]
/// without modifying it.
#[cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))))]
impl From<RustlsConnector> for TlsConnector {
    fn from(connector: RustlsConnector) -> Self {
        Self::Rustls(connector)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default rustls connector builds; a failure panics with the error.
    #[cfg(any(feature = "enable-rustls", feature = "enable-rustls-ring"))]
    #[test]
    fn should_create_default_rustls() {
        drop(TlsConnector::default_rustls().unwrap());
    }

    /// The default native-tls connector builds; a failure panics with the error.
    #[cfg(feature = "enable-native-tls")]
    #[test]
    fn should_create_default_native_tls() {
        drop(TlsConnector::default_native_tls().unwrap());
    }
}