use futures::future::Either;
use futures::{FutureExt, TryFutureExt};
use http::Uri;
use std::future::Future;
use std::sync::Arc;
use thiserror::Error;
use tokio_rustls::rustls::client::ClientConfig;
use tokio_rustls::rustls::pki_types::{InvalidDnsNameError, ServerName};
use crate::Connector;
use crate::eitherio::EitherIO;
/// How a [`TLSConnector`] treats each connection produced by its inner connector.
#[derive(Debug)]
enum TLSConnectorStyle {
    /// Hand the inner connection back unchanged (no TLS).
    Plain,
    /// Run `tokio_rustls::TlsConnector::connect` over the inner connection,
    /// using this server name and this (shared, pre-configured) client config.
    Tls(ServerName<'static>, Arc<ClientConfig>),
}
/// Connector wrapper that optionally layers TLS on top of an inner [`Connector`].
///
/// Whether TLS is used is decided once, when the value is built by
/// `TLSConnector::new` from the target URI's scheme; every later `connect`
/// call follows that decision.
#[derive(Debug)]
pub struct TLSConnector<T> {
    /// Connector that establishes the underlying transport.
    inner: T,
    /// Whether (and with which name/config) to run a TLS handshake per connection.
    style: TLSConnectorStyle,
}
#[derive(Debug, Error)]
pub enum TLSConnectorCreationError {
#[error("https URI without TLS configuration")]
MissingTLSConfig,
#[error("{0}")]
InvalidNameError(#[from] InvalidDnsNameError),
}
impl<T> TLSConnector<T> {
    /// Creates a connector that wraps `inner` and decides from `uri` whether
    /// connections should be upgraded to TLS.
    ///
    /// * `https` URIs use TLS with the URI host as the server name. A host of
    ///   `_` disables SNI. Bracketed IPv6 literals such as `[::1]` have their
    ///   brackets removed before being parsed as a server name.
    /// * `spiffe` URIs (scheme compared ASCII case-insensitively) use TLS with
    ///   the fixed server name `spiffe` and SNI disabled.
    /// * Any other scheme, or no scheme, yields a plain pass-through connector
    ///   and `config` is ignored.
    ///
    /// For the TLS cases `config` is cloned and the clone's ALPN protocol list
    /// is replaced with just `h2`.
    ///
    /// # Errors
    ///
    /// * [`TLSConnectorCreationError::MissingTLSConfig`] if the scheme requires
    ///   TLS but `config` is `None`.
    /// * [`TLSConnectorCreationError::InvalidNameError`] if the `https` host
    ///   cannot be parsed as a [`ServerName`].
    pub fn new(
        inner: T,
        uri: &Uri,
        config: Option<&ClientConfig>,
    ) -> Result<Self, TLSConnectorCreationError> {
        // `http` normalises the standard `https` scheme itself; custom schemes
        // keep their original spelling, and RFC 3986 says schemes are
        // case-insensitive, hence `eq_ignore_ascii_case`.
        let spiffe = match uri.scheme() {
            Some(scheme) if *scheme == http::uri::Scheme::HTTPS => false,
            Some(scheme) if scheme.as_str().eq_ignore_ascii_case("spiffe") => true,
            _ => {
                return Ok(Self {
                    inner,
                    style: TLSConnectorStyle::Plain,
                });
            }
        };
        let Some(c) = config else {
            return Err(TLSConnectorCreationError::MissingTLSConfig);
        };
        let mut c = c.clone();
        c.alpn_protocols = vec![b"h2".to_vec()];
        let name = if spiffe {
            c.enable_sni = false;
            ServerName::try_from("spiffe").expect("\"spiffe\" is a valid DNS name")
        } else {
            let host = uri.host().unwrap_or_default();
            if host == "_" {
                c.enable_sni = false;
            }
            // `Uri::host` keeps the brackets around IPv6 literals ("[::1]"),
            // which `ServerName::try_from` cannot parse; strip them first.
            let host = host
                .strip_prefix('[')
                .and_then(|h| h.strip_suffix(']'))
                .unwrap_or(host);
            ServerName::try_from(host)?
        };
        Ok(Self {
            inner,
            style: TLSConnectorStyle::Tls(name.to_owned(), Arc::new(c)),
        })
    }
}
/// Errors produced while connecting through a [`TLSConnector`].
// NOTE(review): each variant's Display is just "{0}" while `#[from]`/`#[source]`
// also exposes the same error via `source()`, so error-chain reporters may print
// the message twice; consider `#[error(transparent)]` or a distinct message.
#[derive(Debug, Error)]
pub enum TLSConnectorError<T: std::error::Error> {
    /// The inner connector failed; produced from `T` via `From` (`err_into`).
    #[error("{0}")]
    InnerError(#[from] T),
    /// `tokio_rustls::TlsConnector::connect` failed over an established inner
    /// connection.
    #[error("{0}")]
    TLSError(#[source] std::io::Error),
}
impl<A, T> Connector<A> for TLSConnector<T>
where
    T: Connector<A>,
    T::IO: Send + Sync + 'static,
    T::Error: 'static,
{
    type IO = EitherIO<T::IO, tokio_rustls::client::TlsStream<T::IO>>;
    type Error = TLSConnectorError<T::Error>;

    /// Opens a transport with the inner connector, then either hands it back
    /// untouched (`Plain`) or runs a TLS client handshake over it (`Tls`).
    ///
    /// Inner failures surface as `TLSConnectorError::InnerError`; handshake
    /// failures as `TLSConnectorError::TLSError`.
    fn connect(
        &self,
        addr: A,
    ) -> impl Future<Output = Result<Self::IO, Self::Error>> + Send + Sync + 'static {
        let transport = self.inner.connect(addr).err_into();
        match &self.style {
            TLSConnectorStyle::Plain => {
                let passthrough = transport.map_ok(|raw| EitherIO::Left { inner: raw });
                Either::Left(passthrough)
            }
            TLSConnectorStyle::Tls(server_name, tls_config) => {
                // Owned copies so the returned future is `'static`.
                let server_name = server_name.clone();
                let tls_config = Arc::clone(tls_config);
                let upgraded = transport.and_then(move |raw| {
                    let handshake = tokio_rustls::TlsConnector::from(tls_config);
                    handshake.connect(server_name, raw).map(|outcome| {
                        outcome
                            .map(|secured| EitherIO::Right { inner: secured })
                            .map_err(TLSConnectorError::TLSError)
                    })
                });
                Either::Right(upgraded)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_rustls::rustls::RootCertStore;

    /// Builds a connector with a unit inner connector for the given URI text.
    fn build(
        uri: &str,
        conf: Option<&ClientConfig>,
    ) -> Result<TLSConnector<()>, TLSConnectorCreationError> {
        let uri = Uri::try_from(uri).unwrap();
        TLSConnector::new((), &uri, conf)
    }

    /// Returns the DNS server name and SNI flag of a TLS style; panics on any
    /// other style.
    fn tls_parts(style: TLSConnectorStyle) -> (String, bool) {
        let TLSConnectorStyle::Tls(ServerName::DnsName(sn), co) = style else {
            panic!("wrong style");
        };
        let name: &str = sn.as_ref();
        (name.to_owned(), co.enable_sni)
    }

    #[test]
    fn without_tls() {
        for uri in ["http://example.org", "unknown://example.org"] {
            let c = build(uri, None).unwrap();
            assert_matches!(c.style, TLSConnectorStyle::Plain);
        }
        for uri in ["https://example.org", "spiffe://example.org"] {
            let e = build(uri, None).expect_err("no tls");
            assert_matches!(e, TLSConnectorCreationError::MissingTLSConfig);
        }
    }

    #[test]
    fn with_tls() {
        let conf = ClientConfig::builder()
            .with_root_certificates(Arc::new(RootCertStore::empty()))
            .with_no_client_auth();

        for uri in ["http://example.org", "unknown://example.org"] {
            let c = build(uri, Some(&conf)).unwrap();
            assert_matches!(c.style, TLSConnectorStyle::Plain);
        }

        let (name, sni) = tls_parts(build("https://example.org", Some(&conf)).unwrap().style);
        assert_eq!(name, "example.org");
        assert!(sni);

        let (name, sni) = tls_parts(build("spiffe://example.org", Some(&conf)).unwrap().style);
        assert_eq!(name, "spiffe");
        assert!(!sni);
    }
}