trz-gateway-common 0.2.9

Secure Proxy / Agents implementation in Rust
Documentation
use std::sync::Arc;
use std::time::SystemTime;

use nameth::NamedEnumValues as _;
use nameth::nameth;
use openssl::error::ErrorStack;
use openssl::pkey::PKey;
use openssl::pkey::Private;
use openssl::x509::X509;
use rustls::ServerConfig;
use rustls::pki_types::CertificateDer;
use rustls::pki_types::PrivateKeyDer;
use rustls::server::ClientHello;
use rustls::server::ResolvesServerCert;
use rustls::sign::CertifiedKey;
use tracing::Level;
use tracing::debug;
use tracing::info;
use tracing::info_span;
use tracing::warn;

use super::CertificateConfig;
use crate::certificate_info::CertificateInfo;
use crate::certificate_info::X509CertificateInfo;
use crate::crypto_provider::crypto_provider;
use crate::security_configuration::certificate::display_x509_certificate;
use crate::x509::time::asn1_to_system_time;

/// Builds a RusTLS [ServerConfig] out of any [CertificateConfig].
///
/// Static configurations are turned into a fixed certificate chain, while
/// dynamic ones get a certificate resolver that consults the configuration
/// again whenever rustls asks for a server certificate.
pub trait ToTlsServer: CertificateConfig + Sized {
    /// Consumes the configuration and produces a shared TLS [ServerConfig]
    /// advertising the `h2` and `http/1.1` ALPN protocols.
    fn to_tls_server(self) -> Result<Arc<ServerConfig>, ToTlsServerError<Self::Error>> {
        to_tls_server_impl::<Self>(self)
    }
}

/// Every [CertificateConfig] gets [ToTlsServer] for free.
impl<C> ToTlsServer for C where C: CertificateConfig {}

/// Assembles the rustls [ServerConfig] for `certificate_config`.
///
/// Client authentication is disabled. A static configuration is read once and
/// its chain baked into the config; a dynamic one is wrapped in a
/// [ServerCertificateResolver] that is consulted on each certificate lookup.
fn to_tls_server_impl<T: CertificateConfig>(
    certificate_config: T,
) -> Result<Arc<ServerConfig>, ToTlsServerError<T::Error>> {
    let _span = info_span!("Setup TLS server certificate").entered();
    // Called for its side effect only; the return value is not needed here.
    crypto_provider();
    let builder = ServerConfig::builder().with_no_client_auth();
    let mut tls_config = if !certificate_config.is_dynamic() {
        let certificate = certificate_config
            .certificate()
            .map_err(ToTlsServerError::Certificate)?;
        let intermediates = certificate_config
            .intermediates()
            .map_err(ToTlsServerError::Intermediates)?;
        let (chain, key) = build_single_cert::<T>(&certificate, &intermediates)?;
        builder
            .with_single_cert(chain, key)
            .map_err(ToTlsServerError::ServerConfig)?
    } else {
        let resolver = ServerCertificateResolver {
            state: Default::default(),
            certificate_config,
        };
        builder.with_cert_resolver(Arc::new(resolver))
    };
    tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    Ok(Arc::new(tls_config))
}

/// Converts a server certificate, its intermediates and its private key into
/// the DER forms rustls expects.
///
/// The returned chain starts with the server certificate, followed by the
/// intermediates in the order given. Each certificate is logged along the way.
///
/// # Errors
/// Any OpenSSL DER-encoding failure, or rustls rejecting the private key bytes.
fn build_single_cert<T: CertificateConfig>(
    certificate: &X509CertificateInfo,
    intermediates: &[X509],
) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), ToTlsServerError<T::Error>> {
    log_server_certiticate(certificate);
    let leaf_der = certificate
        .certificate
        .to_der()
        .map_err(ToTlsServerError::CertificateToDer)?;
    let mut chain: Vec<CertificateDer<'static>> = vec![leaf_der.into()];

    for intermediate in intermediates {
        info!(
            "Intermediate certificate: {:?} issued by {:?}",
            intermediate.subject_name(),
            intermediate.issuer_name()
        );
        debug!(
            "Intermediate certificate details: {}",
            display_x509_certificate(intermediate)
        );
        let intermediate_der = intermediate
            .to_der()
            .map_err(ToTlsServerError::IntermediateToDer)?;
        chain.push(intermediate_der.into());
    }

    let key_der = certificate
        .private_key
        .private_key_to_der()
        .map_err(ToTlsServerError::PrivateKeyToDer)?;
    let private_key = key_der
        .try_into()
        .map_err(ToTlsServerError::ToPrivateKey)?;

    Ok((chain, private_key))
}

/// Certificate resolver installed for dynamic [CertificateConfig]s: every
/// `resolve` call asks the configuration for its current certificate and
/// intermediates, rebuilding the rustls key only when they changed.
struct ServerCertificateResolver<T> {
    /// Source of the server certificate and its intermediate chain.
    certificate_config: T,
    /// Last successfully built key plus the inputs it was built from; `None`
    /// until the first successful resolution. Behind a mutex because
    /// `resolve` only gets `&self`.
    state: std::sync::Mutex<Option<CertResolverState>>,
}

/// Cache entry of [ServerCertificateResolver]: the built key together with the
/// exact `Arc`s it was built from, so a later resolution can tell (by pointer
/// identity) whether the configuration handed back a new certificate.
struct CertResolverState {
    /// Certificate chain and signing key handed to rustls.
    certified_key: Arc<CertifiedKey>,
    /// Server certificate `certified_key` was built from.
    certificate: Arc<X509CertificateInfo>,
    /// Intermediate certificates `certified_key` was built from.
    intermediates: Arc<Vec<X509>>,
}

/// Opaque `Debug` output, required by rustls' `ResolvesServerCert` bound.
impl<T> std::fmt::Debug for ServerCertificateResolver<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `T` is not required to implement `Debug`, and the cached key contains
        // private key material, so no fields are printed. Use the real type name
        // and `finish_non_exhaustive` so the omission is explicit:
        // `ServerCertificateResolver { .. }`.
        f.debug_struct("ServerCertificateResolver")
            .finish_non_exhaustive()
    }
}

impl<T: CertificateConfig> ResolvesServerCert for ServerCertificateResolver<T> {
    /// Chooses the certificate to present for an incoming TLS handshake.
    ///
    /// If a fresh certificate cannot be obtained, the last successfully built
    /// key is reused (stale but still usable). `None` is returned, which makes
    /// rustls abort the handshake, only when no key has ever been cached.
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        let _span = info_span!(
            "Resolve server certificate",
            host = client_hello.server_name()
        )
        .entered();
        // The cache is only ever replaced as a whole (`*state = Some(..)`), so a
        // panic during another handshake cannot leave it half-updated. Recover the
        // guard from a poisoned mutex instead of panicking on every later handshake.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        match self.resolve_impl(&mut state) {
            Ok(certified_key) => Some(certified_key),
            Err(error) => {
                warn!("Failed to resolve server certificate: {error}");
                if let Some(state) = &*state {
                    info!("Reuse stale cached server certificate");
                    Some(state.certified_key.clone())
                } else {
                    None
                }
            }
        }
    }
}

impl<T: CertificateConfig> ServerCertificateResolver<T> {
    /// Returns the [CertifiedKey] for the configuration's current certificate.
    ///
    /// The key is rebuilt only when the configuration returns a certificate or
    /// intermediates `Arc` that differs from the cached one. Otherwise the cached
    /// key is returned as-is. On a rebuild, `state` is overwritten with the new
    /// key and its inputs.
    ///
    /// # Errors
    /// Failures from the configuration ([ToTlsServerError::Certificate],
    /// [ToTlsServerError::Intermediates]) or from building the key. `state` is
    /// left untouched in that case.
    fn resolve_impl(
        &self,
        state: &mut Option<CertResolverState>,
    ) -> Result<Arc<CertifiedKey>, ToTlsServerError<T::Error>> {
        let certificate = self
            .certificate_config
            .certificate()
            .map_err(ToTlsServerError::Certificate)?;
        let intermediates = self
            .certificate_config
            .intermediates()
            .map_err(ToTlsServerError::Intermediates)?;

        // Pointer identity, not content comparison: the same `Arc` means the
        // configuration has not produced a new certificate since the last build.
        if let Some(state) = state
            && Arc::ptr_eq(&certificate, &state.certificate)
            && Arc::ptr_eq(&intermediates, &state.intermediates)
        {
            debug!("Reuse cached server certificate");
            return Ok(state.certified_key.clone());
        }

        // No explicit logging here: `make_certified_key` -> `build_single_cert`
        // already logs the server certificate, and logging it here too printed
        // every refresh twice.
        let certified_key = self.make_certified_key(&certificate, &intermediates)?;
        *state = Some(CertResolverState {
            certified_key: certified_key.clone(),
            certificate,
            intermediates,
        });
        Ok(certified_key)
    }

    /// Builds a rustls [CertifiedKey] (chain + signing key) from the server
    /// certificate and intermediates, using the crate's crypto provider.
    fn make_certified_key(
        &self,
        certificate: &X509CertificateInfo,
        intermediates: &[X509],
    ) -> Result<Arc<CertifiedKey>, ToTlsServerError<T::Error>> {
        let (certificate_chain, private_key) = build_single_cert::<T>(certificate, intermediates)?;
        let certified_key =
            CertifiedKey::from_der(certificate_chain, private_key, crypto_provider())
                .map_err(ToTlsServerError::CertifiedKey)?;
        Ok(Arc::new(certified_key))
    }
}

/// Errors raised while turning a [CertificateConfig] into a rustls server
/// configuration or certified key.
///
/// Each message is prefixed with `[{n}]`, where `n` comes from the
/// `nameth`-generated `name()` (presumably the variant name — confirm against
/// the `nameth` crate).
#[nameth]
#[derive(thiserror::Error, Debug)]
pub enum ToTlsServerError<E: std::error::Error> {
    /// The configuration failed to provide the server certificate.
    #[error("[{n}] {0}", n = self.name())]
    Certificate(E),

    /// OpenSSL failed to DER-encode the server certificate.
    #[error("[{n}] {0}", n = self.name())]
    CertificateToDer(ErrorStack),

    /// The configuration failed to provide the intermediate certificates.
    #[error("[{n}] {0}", n = self.name())]
    Intermediates(E),

    /// OpenSSL failed to DER-encode one of the intermediate certificates.
    #[error("[{n}] {0}", n = self.name())]
    IntermediateToDer(ErrorStack),

    /// OpenSSL failed to DER-encode the server private key.
    #[error("[{n}] {0}", n = self.name())]
    PrivateKeyToDer(ErrorStack),

    /// The DER bytes could not be converted into a rustls `PrivateKeyDer`.
    #[error("[{n}] {0}", n = self.name())]
    ToPrivateKey(&'static str),

    /// rustls rejected the chain/key while building a static [ServerConfig].
    #[error("[{n}] {0}", n = self.name())]
    ServerConfig(rustls::Error),

    /// rustls failed to build a [CertifiedKey] for the dynamic resolver.
    #[error("[{n}] {0}", n = self.name())]
    CertifiedKey(rustls::Error),
}

/// Logs a one-line summary of the server certificate at INFO level (subject,
/// issuer, `notAfter` and time remaining) and its full details at DEBUG level.
///
/// Returns immediately when INFO is disabled, so no formatting work is done.
fn log_server_certiticate(certificate: &CertificateInfo<X509, PKey<Private>>) {
    if !tracing::enabled!(Level::INFO) {
        return;
    }
    let now = SystemTime::now();
    let subject = certificate.certificate.subject_name();
    let issuer = certificate.certificate.issuer_name();
    let not_after = certificate.certificate.not_after();
    let expiration = match asn1_to_system_time(not_after) {
        Ok(not_after) => match not_after.duration_since(now) {
            Ok(remaining) => humantime::format_duration(remaining).to_string(),
            // `duration_since` fails exactly when `not_after` is already in the
            // past. Report that as an expiry, with how long ago it happened,
            // instead of the opaque SystemTimeError message.
            Err(expired) => format!(
                "expired {} ago",
                humantime::format_duration(expired.duration())
            ),
        },
        Err(error) => format!("Err: {error}"),
    };
    info!(
        "Server certificate: {subject:?} issued by {issuer:?} expires {not_after} ({expiration})"
    );
    debug!("Server certificate details: {}", certificate.display());
}