volo 0.12.3

Volo is a high-performance, highly extensible Rust RPC framework that helps developers build microservices.
Documentation
use std::{io, io::Result, sync::Arc};

use rustls::{RootCertStore, ServerConfig, pki_types::ServerName};
use rustls_pki_types::PrivateKeyDer;
use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector, rustls::ClientConfig};

use super::{Acceptor, Connector, TlsConnectorBuilder};

/// A wrapper for [`tokio_rustls::TlsConnector`]
///
/// This is the rustls-backed implementation of the parent module's `Connector` trait. It can be
/// created with `Connector::build`, with `Default::default`, or by converting a rustls
/// `ClientConfig` / `tokio_rustls::TlsConnector` into the parent module's `TlsConnector`.
/// The inner connector is only reachable from the parent module (`pub(super)`).
#[derive(Clone)]
pub struct RustlsConnector(pub(super) TlsConnector);

/// A wrapper for [`tokio_rustls::TlsAcceptor`]
///
/// This is the rustls-backed implementation of the parent module's `Acceptor` trait. It can be
/// created with `Acceptor::from_pem`, or by converting a rustls `ServerConfig` /
/// `tokio_rustls::TlsAcceptor` into the parent module's `TlsAcceptor`.
/// The inner acceptor is only reachable from the parent module (`pub(super)`).
#[derive(Clone)]
pub struct RustlsAcceptor(pub(super) TlsAcceptor);

impl Default for RustlsConnector {
    fn default() -> Self {
        Self::build(TlsConnectorBuilder::default()).expect("Failed to create RustlsConnector")
    }
}

/// Returns the rustls crypto provider selected through Cargo features.
///
/// The `rustls-aws-lc-rs` feature selects aws-lc-rs and the `rustls-ring` feature selects ring.
///
/// # Panics
///
/// Panics unless exactly one of `rustls-aws-lc-rs` / `rustls-ring` is enabled. The panic message
/// distinguishes a conflicting selection (both enabled) from a missing one (neither enabled).
fn rustls_crypto_provider() -> rustls::crypto::CryptoProvider {
    #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
    {
        return rustls::crypto::aws_lc_rs::default_provider();
    }

    #[cfg(all(feature = "rustls-ring", not(feature = "rustls-aws-lc-rs")))]
    {
        return rustls::crypto::ring::default_provider();
    }

    // Only reached when zero or both provider features are enabled. With exactly one enabled,
    // a `return` above makes this block unreachable, hence the `allow`.
    #[allow(unreachable_code)]
    {
        if cfg!(all(feature = "rustls-aws-lc-rs", feature = "rustls-ring")) {
            panic!("only one of `rustls-aws-lc-rs` or `rustls-ring` can be selected.")
        }
        panic!("one of `rustls-aws-lc-rs` or `rustls-ring` must be selected.")
    }
}

impl Connector for RustlsConnector {
    /// Builds a rustls-backed client connector from `builder`.
    ///
    /// The root store starts from the bundled `webpki_roots` anchors when
    /// `builder.default_root_certs` is set, and is empty otherwise. Every certificate found in
    /// each PEM of `builder.pems` is then added, so a PEM bundle contributes all of its
    /// certificates, not only the first. `builder.alpn_protocols` is advertised via ALPN.
    /// No client certificate is configured.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] in any of these cases:
    /// - a PEM contains no certificate or fails to parse;
    /// - a certificate is rejected by the root store;
    /// - rustls cannot build a client configuration with its default protocol versions.
    fn build(builder: TlsConnectorBuilder) -> Result<Self> {
        let mut certs = if builder.default_root_certs {
            RootCertStore {
                roots: webpki_roots::TLS_SERVER_ROOTS.to_owned(),
            }
        } else {
            RootCertStore::empty()
        };
        for pem in builder.pems {
            // A PEM may be a bundle of several certificates: trust every one of them instead of
            // silently keeping just the first.
            let mut reader: &[u8] = pem.as_ref();
            let mut added = 0usize;
            for cert in rustls_pemfile::certs(&mut reader) {
                certs
                    .add(cert?)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
                added += 1;
            }
            if added == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "No certificate found in the provided PEM",
                ));
            }
        }
        let mut client_config =
            ClientConfig::builder_with_provider(Arc::new(rustls_crypto_provider()))
                .with_protocol_versions(rustls::DEFAULT_VERSIONS)
                // `build` already returns `io::Result`, so report this instead of panicking.
                .map_err(io::Error::other)?
                .with_root_certificates(certs)
                .with_no_client_auth();
        client_config.alpn_protocols = builder
            .alpn_protocols
            .into_iter()
            .map(String::into_bytes)
            .collect();
        let connector = TlsConnector::from(Arc::new(client_config));
        Ok(Self(connector))
    }

    /// Runs a TLS client handshake over `tcp_stream`, passing `server_name` to rustls as the
    /// server name.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] if `server_name` cannot be parsed as a rustls
    /// `ServerName`. Otherwise, returns any error produced by the handshake.
    async fn connect(
        &self,
        server_name: &str,
        tcp_stream: TcpStream,
    ) -> io::Result<super::TlsStream> {
        // `connect` needs an owned server name, so detach it from the borrowed `&str`.
        let sni = ServerName::try_from(server_name)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
            .to_owned();
        tracing::trace!("RustlsConnector::connect({server_name:?})");
        self.0
            .connect(sni, tcp_stream)
            .await
            .map(tokio_rustls::TlsStream::Client)
            .map(Into::into)
    }
}

impl Acceptor for RustlsAcceptor {
    fn from_pem(cert: Vec<u8>, key: Vec<u8>) -> Result<Self> {
        let cert = rustls_pemfile::certs(&mut cert.as_ref()).collect::<Result<Vec<_>>>()?;
        let key = rustls_pemfile::pkcs8_private_keys(&mut key.as_ref())
            .collect::<Result<Vec<_>>>()?
            .pop()
            .map(PrivateKeyDer::Pkcs8)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "No private key found"))?;
        let server_config = ServerConfig::builder_with_provider(Arc::new(rustls_crypto_provider()))
            .with_protocol_versions(rustls::DEFAULT_VERSIONS)
            .expect("something wrong on rustls ServerConfig")
            .with_no_client_auth()
            .with_single_cert(cert, key)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        let acceptor = TlsAcceptor::from(Arc::new(server_config));
        Ok(Self(acceptor))
    }

    async fn accept(&self, tcp_stream: TcpStream) -> Result<super::TlsStream> {
        tracing::trace!("RustlsAcceptor::accept");
        self.0
            .accept(tcp_stream)
            .await
            .map(tokio_rustls::TlsStream::Server)
            .map(Into::into)
    }
}

impl From<ClientConfig> for super::TlsConnector {
    /// Wraps an owned rustls client configuration.
    ///
    /// The configuration is moved into an `Arc` and handled by the `Arc<ClientConfig>` conversion.
    fn from(client_config: ClientConfig) -> Self {
        Self::from(Arc::new(client_config))
    }
}

impl From<Arc<ClientConfig>> for super::TlsConnector {
    fn from(client_config: Arc<ClientConfig>) -> Self {
        Self::Rustls(RustlsConnector(TlsConnector::from(client_config)))
    }
}

impl From<TlsConnector> for super::TlsConnector {
    /// Wraps an existing `tokio_rustls::TlsConnector` in the rustls variant of the connector.
    fn from(connector: TlsConnector) -> Self {
        let wrapped = RustlsConnector(connector);
        Self::Rustls(wrapped)
    }
}

impl From<ServerConfig> for super::TlsAcceptor {
    /// Wraps an owned rustls server configuration.
    ///
    /// The configuration is moved into an `Arc` and handled by the `Arc<ServerConfig>` conversion.
    fn from(server_config: ServerConfig) -> Self {
        Self::from(Arc::new(server_config))
    }
}

impl From<Arc<ServerConfig>> for super::TlsAcceptor {
    fn from(server_config: Arc<ServerConfig>) -> Self {
        Self::Rustls(RustlsAcceptor(TlsAcceptor::from(server_config)))
    }
}

impl From<TlsAcceptor> for super::TlsAcceptor {
    /// Wraps an existing `tokio_rustls::TlsAcceptor` in the rustls variant of the acceptor.
    fn from(acceptor: TlsAcceptor) -> Self {
        let wrapped = RustlsAcceptor(acceptor);
        Self::Rustls(wrapped)
    }
}