use std::{io, io::Result, sync::Arc};
use rustls::{RootCertStore, ServerConfig, pki_types::ServerName};
use rustls_pki_types::PrivateKeyDer;
use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector, rustls::ClientConfig};
use super::{Acceptor, Connector, TlsConnectorBuilder};
/// rustls-backed TLS client connector; a thin wrapper around
/// `tokio_rustls::TlsConnector`.
#[derive(Clone)]
pub struct RustlsConnector(pub(in super) TlsConnector);
/// rustls-backed TLS server acceptor; a thin wrapper around
/// `tokio_rustls::TlsAcceptor`.
#[derive(Clone)]
pub struct RustlsAcceptor(pub(in super) TlsAcceptor);
impl Default for RustlsConnector {
    /// Builds a connector from a default `TlsConnectorBuilder`.
    ///
    /// # Panics
    ///
    /// Panics if building the connector from the default builder fails.
    fn default() -> Self {
        let builder = TlsConnectorBuilder::default();
        Self::build(builder).expect("Failed to create RustlsConnector")
    }
}
/// Returns the rustls crypto provider selected by the crate's cargo features.
///
/// With exactly one of `rustls-aws-lc-rs` / `rustls-ring` enabled, the
/// corresponding `default_provider()` is returned.
///
/// # Panics
///
/// Panics if both features are enabled, or if neither is.
fn rustls_crypto_provider() -> rustls::crypto::CryptoProvider {
    #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
    {
        return rustls::crypto::aws_lc_rs::default_provider();
    }
    #[cfg(all(feature = "rustls-ring", not(feature = "rustls-aws-lc-rs")))]
    {
        return rustls::crypto::ring::default_provider();
    }
    // Only reached with an invalid feature selection. In valid configurations
    // one of the `return`s above makes this block unreachable, hence the allow.
    // The tail block is not `cfg`-gated because removing a tail expression via
    // `cfg` is not supported.
    #[allow(unreachable_code)]
    {
        if cfg!(all(feature = "rustls-aws-lc-rs", feature = "rustls-ring")) {
            panic!("only one of `rustls-aws-lc-rs` or `rustls-ring` can be selected.")
        }
        panic!("one of `rustls-aws-lc-rs` or `rustls-ring` must be selected.")
    }
}
impl Connector for RustlsConnector {
    /// Builds a rustls-backed connector from `builder`.
    ///
    /// The root store starts from `webpki_roots::TLS_SERVER_ROOTS` when
    /// `builder.default_root_certs` is true, and is empty otherwise. The first
    /// certificate of each PEM in `builder.pems` is then added to it. The
    /// configured ALPN protocols are copied into the client config as bytes.
    ///
    /// # Errors
    ///
    /// - PEM parse errors are propagated.
    /// - `InvalidInput` if a PEM contains no certificate, or if the root store
    ///   rejects a certificate.
    /// - `Other` if the selected crypto provider cannot be configured with
    ///   `rustls::DEFAULT_VERSIONS`. This is reported as an error, not a panic.
    fn build(builder: TlsConnectorBuilder) -> Result<Self> {
        let mut certs = if builder.default_root_certs {
            RootCertStore {
                roots: webpki_roots::TLS_SERVER_ROOTS.to_owned(),
            }
        } else {
            RootCertStore::empty()
        };
        for pem in builder.pems.into_iter() {
            // Only the first certificate in each PEM is taken.
            let cert = rustls_pemfile::certs(&mut pem.as_ref())
                .next()
                .ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "No certificate found in the provided PEM",
                    )
                })??;
            certs
                .add(cert)
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        }
        let mut client_config =
            ClientConfig::builder_with_provider(Arc::new(rustls_crypto_provider()))
                .with_protocol_versions(rustls::DEFAULT_VERSIONS)
                // A provider/version mismatch is a configuration error; surface
                // it through the `io::Result` instead of panicking.
                .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
                .with_root_certificates(certs)
                .with_no_client_auth();
        client_config.alpn_protocols = builder
            .alpn_protocols
            .into_iter()
            .map(String::into_bytes)
            .collect();
        let connector = TlsConnector::from(Arc::new(client_config));
        Ok(Self(connector))
    }

    /// Performs a TLS client handshake over `tcp_stream` for `server_name`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if `server_name` cannot be converted into a rustls
    /// `ServerName`. Errors from the handshake are propagated.
    async fn connect(
        &self,
        server_name: &str,
        tcp_stream: TcpStream,
    ) -> io::Result<super::TlsStream> {
        let sni = ServerName::try_from(server_name)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
            .to_owned();
        tracing::trace!("RustlsConnector::connect({server_name:?})");
        self.0
            .connect(sni, tcp_stream)
            .await
            .map(tokio_rustls::TlsStream::Client)
            .map(Into::into)
    }
}
impl Acceptor for RustlsAcceptor {
    /// Builds a rustls-backed acceptor from a PEM certificate chain and a PEM
    /// private key.
    ///
    /// All certificates found in `cert` are passed to rustls in the order they
    /// appear. The key may be PKCS#1 (RSA), PKCS#8 or SEC1 (EC). If `key`
    /// contains several keys, the last one is used.
    ///
    /// # Errors
    ///
    /// - PEM parse errors are propagated.
    /// - `InvalidInput` if `cert` contains no certificate, if `key` contains no
    ///   private key, or if rustls rejects the certificate/key pair.
    /// - `Other` if the selected crypto provider cannot be configured with
    ///   `rustls::DEFAULT_VERSIONS`.
    fn from_pem(cert: Vec<u8>, key: Vec<u8>) -> Result<Self> {
        let cert = rustls_pemfile::certs(&mut cert.as_ref()).collect::<Result<Vec<_>>>()?;
        // An empty chain can never authenticate the server; reject it up front
        // with a clear message.
        if cert.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "No certificate found in the provided PEM",
            ));
        }
        // Accept every private-key encoding rustls understands, not just PKCS#8.
        // Any PEM parse error aborts, and the last key found is kept.
        let key = rustls_pemfile::read_all(&mut key.as_ref())
            .filter_map(|item| match item {
                Ok(rustls_pemfile::Item::Pkcs1Key(k)) => Some(Ok(PrivateKeyDer::Pkcs1(k))),
                Ok(rustls_pemfile::Item::Pkcs8Key(k)) => Some(Ok(PrivateKeyDer::Pkcs8(k))),
                Ok(rustls_pemfile::Item::Sec1Key(k)) => Some(Ok(PrivateKeyDer::Sec1(k))),
                Ok(_) => None,
                Err(e) => Some(Err(e)),
            })
            .collect::<Result<Vec<_>>>()?
            .pop()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "No private key found"))?;
        let server_config = ServerConfig::builder_with_provider(Arc::new(rustls_crypto_provider()))
            .with_protocol_versions(rustls::DEFAULT_VERSIONS)
            // Configuration error: report it instead of panicking.
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
            .with_no_client_auth()
            .with_single_cert(cert, key)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        let acceptor = TlsAcceptor::from(Arc::new(server_config));
        Ok(Self(acceptor))
    }

    /// Performs a TLS server handshake on `tcp_stream`. Handshake errors are
    /// propagated.
    async fn accept(&self, tcp_stream: TcpStream) -> Result<super::TlsStream> {
        tracing::trace!("RustlsAcceptor::accept");
        self.0
            .accept(tcp_stream)
            .await
            .map(tokio_rustls::TlsStream::Server)
            .map(Into::into)
    }
}
impl From<ClientConfig> for super::TlsConnector {
    /// Wraps an owned rustls `ClientConfig` in the rustls connector variant.
    fn from(client_config: ClientConfig) -> Self {
        // Share the config behind an `Arc` and reuse the `Arc<ClientConfig>` conversion.
        Self::from(Arc::new(client_config))
    }
}
impl From<Arc<ClientConfig>> for super::TlsConnector {
    /// Wraps a shared rustls `ClientConfig` in the rustls connector variant.
    fn from(client_config: Arc<ClientConfig>) -> Self {
        let connector = TlsConnector::from(client_config);
        Self::from(connector)
    }
}
impl From<TlsConnector> for super::TlsConnector {
    /// Wraps an existing `tokio_rustls` connector in the rustls variant.
    fn from(connector: TlsConnector) -> Self {
        let wrapped = RustlsConnector(connector);
        Self::Rustls(wrapped)
    }
}
impl From<ServerConfig> for super::TlsAcceptor {
    /// Wraps an owned rustls `ServerConfig` in the rustls acceptor variant.
    fn from(server_config: ServerConfig) -> Self {
        // Share the config behind an `Arc` and reuse the `Arc<ServerConfig>` conversion.
        Self::from(Arc::new(server_config))
    }
}
impl From<Arc<ServerConfig>> for super::TlsAcceptor {
    /// Wraps a shared rustls `ServerConfig` in the rustls acceptor variant.
    fn from(server_config: Arc<ServerConfig>) -> Self {
        let acceptor = TlsAcceptor::from(server_config);
        Self::from(acceptor)
    }
}
impl From<TlsAcceptor> for super::TlsAcceptor {
    /// Wraps an existing `tokio_rustls` acceptor in the rustls variant.
    fn from(acceptor: TlsAcceptor) -> Self {
        let wrapped = RustlsAcceptor(acceptor);
        Self::Rustls(wrapped)
    }
}