rsub 0.1.0

A high-performance message broker with QUIC transport and pub/sub messaging patterns
Documentation
use std::{fs, path::Path, sync::Arc};

use anyhow::{anyhow, Result};
use quinn::crypto::rustls::{QuicClientConfig, QuicServerConfig};
use rcgen::{self, date_time_ymd, CertificateParams, DistinguishedName, DnType, KeyPair, SanType};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};

use crate::server::config::TlsConfig;

use super::ALPN_QUIC_HTTP;

pub fn create_server_config(tls_config: &TlsConfig) -> Result<(quinn::ServerConfig, Vec<u8>)> {
    rustls::crypto::ring::default_provider()
        .install_default()
        .map_err(|_| anyhow!("Failed to install rustls crypto provider"))?;

    let (key_path, cert_path) = match (&tls_config.key_file, &tls_config.cert_file) {
        (Some(key), Some(cert)) => (key, cert),
        _ => return Err(anyhow!("TLS configuration is missing")),
    };

    let key = fs::read(key_path).map_err(|e| anyhow!("Failed to read private key: {}", e))?;
    let key = if Path::new(key_path)
        .file_name()
        .and_then(|s| s.to_str())
        .map_or(false, |s| s.ends_with(".der"))
    {
        PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key))
    } else {
        rustls_pemfile::private_key(&mut &*key)
            .map_err(|_| anyhow!("Malformed PKCS #1 private key"))?
            .ok_or_else(|| anyhow!("No private keys found"))?
    };

    let cert_chain =
        fs::read(cert_path).map_err(|e| anyhow!("Failed to read certificate chain: {}", e))?;
    let cert_chain = if Path::new(cert_path)
        .file_name()
        .and_then(|s| s.to_str())
        .map_or(false, |s| s.ends_with(".der"))
    {
        vec![CertificateDer::from(cert_chain)]
    } else {
        rustls_pemfile::certs(&mut &*cert_chain)
            .collect::<Result<_, _>>()
            .map_err(|_| anyhow!("Invalid PEM-encoded certificate"))?
    };

    let mut server_crypto = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert_chain.clone(), key)
        .map_err(|e| match e {
            rustls::Error::InvalidCertificate(rustls::CertificateError::Other(ref other))
                if other.to_string().contains("UnsupportedCertVersion") =>
            {
                anyhow!("Unsupported certificate version")
            }
            _ => anyhow!("Failed to create server config: {}", e),
        })?;
    server_crypto.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect();
    let server_crypto = Arc::new(server_crypto);
    let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
        QuicServerConfig::try_from(server_crypto)
            .map_err(|e| anyhow!("Failed to create QuicServerConfig: {}", e))?,
    ));
    let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
    transport_config.max_concurrent_uni_streams(0_u8.into());
    let server_cert = cert_chain[0].as_ref().to_vec();

    Ok((server_config, server_cert))
}

pub fn create_client_config(ca_path: &str) -> Result<quinn::ClientConfig> {
    rustls::crypto::ring::default_provider()
        .install_default()
        .map_err(|_| anyhow!("Failed to install rustls crypto provider"))?;

    let mut roots = rustls::RootCertStore::empty();
    roots
        .add(CertificateDer::from(fs::read(ca_path)?))
        .map_err(|e| anyhow!("Failed to create root cert store: {}", e))?;

    let mut client_crypto = rustls::ClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth();

    client_crypto.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect();

    Ok(quinn::ClientConfig::new(Arc::new(
        QuicClientConfig::try_from(client_crypto)
            .map_err(|e| anyhow!("Failed to create QuicClientConfig: {}", e))?,
    )))
}

/// Generates a fresh key pair and a self-signed certificate, then writes both to disk.
///
/// Four files are written, in this order:
/// 1. `cert_path`: the certificate, PEM-encoded;
/// 2. `cert_path` with its extension replaced by `der`: the certificate, DER-encoded;
/// 3. `key_path`: the private key, PEM-encoded;
/// 4. `key_path` with its extension replaced by `der`: the private key, DER-encoded.
///
/// If a path already ends in `.der`, its DER sibling is the same file, so the DER write
/// replaces the PEM content. Missing parent directories are created.
///
/// When `params` is `None`, the certificate uses these defaults:
/// - validity: 1975-01-01 to 4096-01-01;
/// - organization name: "Crab widgits SE";
/// - common name: "Master Cert";
/// - DNS SANs: `crabs.crabs` and `localhost`.
///
/// In every case a `localhost` DNS SAN is added if absent.
///
/// # Errors
///
/// Returns an error if key generation, certificate signing, directory creation or any
/// file write fails.
pub fn generate_self_signed_cert(
    key_path: &str,
    cert_path: &str,
    params: Option<CertificateParams>,
) -> Result<()> {
    let mut params = params.unwrap_or_else(|| {
        let mut p: CertificateParams = Default::default();
        p.not_before = date_time_ymd(1975, 1, 1);
        p.not_after = date_time_ymd(4096, 1, 1);
        p.distinguished_name = DistinguishedName::new();
        p.distinguished_name
            .push(DnType::OrganizationName, "Crab widgits SE");
        p.distinguished_name.push(DnType::CommonName, "Master Cert");
        p.subject_alt_names = vec![
            SanType::DnsName(
                "crabs.crabs"
                    .try_into()
                    .expect("constant DNS name is a valid IA5 string"),
            ),
            SanType::DnsName(
                "localhost"
                    .try_into()
                    .expect("constant DNS name is a valid IA5 string"),
            ),
        ];
        p
    });

    // Ensure we always have localhost as a SAN.
    if !params
        .subject_alt_names
        .iter()
        .any(|san| matches!(san, SanType::DnsName(name) if name == "localhost"))
    {
        params.subject_alt_names.push(SanType::DnsName(
            "localhost"
                .try_into()
                .expect("constant DNS name is a valid IA5 string"),
        ));
    }

    let key_pair =
        KeyPair::generate().map_err(|e| anyhow!("Failed to generate key pair: {}", e))?;
    let cert = params
        .self_signed(&key_pair)
        .map_err(|e| anyhow!("Failed to generate self-signed certificate: {}", e))?;

    let cert_path = Path::new(cert_path);
    let key_path = Path::new(key_path);

    // A bare file name has an empty parent (a no-op for `create_dir_all`). A path with no
    // parent at all (e.g. "" or "/") is skipped instead of panicking.
    for path in [cert_path, key_path] {
        if let Some(dir) = path.parent() {
            fs::create_dir_all(dir)?;
        }
    }

    // `with_extension` changes only the final extension. Directory names are untouched,
    // and a path without `.pem` (e.g. `cert.crt`) gets a distinct `cert.der` sibling
    // rather than having its freshly written PEM overwritten with DER.
    fs::write(cert_path, cert.pem().as_bytes())?;
    fs::write(cert_path.with_extension("der"), cert.der())?;
    fs::write(key_path, key_pair.serialize_pem().as_bytes())?;
    fs::write(key_path.with_extension("der"), key_pair.serialize_der())?;

    Ok(())
}