rusty-penguin 0.8.3

A fast TCP/UDP tunnel, transported over HTTP WebSocket
Documentation
//! TLS-related code for `native-tls`.
//
// SPDX-License-Identifier: Apache-2.0 OR GPL-3.0-or-later

use super::Error;
use tokio_native_tls::native_tls::{Certificate, Identity, TlsAcceptor, TlsConnector};

/// Type alias for the inner TLS identity type.
///
/// With the `native-tls` backend, the server-side "identity" is an async `TlsAcceptor`
/// built from the server's certificate chain and private key.
pub type TlsIdentityInner = tokio_native_tls::TlsAcceptor;

/// Type alias for the Hyper HTTPS connector.
///
/// `hyper_tls` layered over hyper-util's legacy `HttpConnector`; only compiled with the
/// `server` feature.
#[cfg(feature = "server")]
pub type HyperConnector =
    hyper_tls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>;

/// Create a `TlsAcceptor` from a certificate chain and private key stored on disk.
///
/// `cert_path` names the PEM certificate chain and `key_path` the PKCS#8 PEM private key.
/// Client certificate verification is not available with `native-tls`, so
/// `client_ca_path` must be `None`.
///
/// # Errors
///
/// - `Error::ReadCert` if either file cannot be read.
/// - A converted `native-tls` error if the identity cannot be parsed or the acceptor
///   cannot be built.
/// - `Error::UnsupportedFeature` if `client_ca_path` is `Some`. This is checked only after
///   both files were loaded successfully.
pub async fn make_server_config(
    cert_path: &str,
    key_path: &str,
    client_ca_path: Option<&str>,
) -> Result<TlsIdentityInner, Error> {
    make_server_config_from_mem(read_key_cert(key_path, cert_path).await?, client_ca_path)
}

/// Create a `TlsAcceptor` from in-memory PEM data, such as a keypair generated by `rcgen`.
///
/// `certs` is the PEM certificate chain and `priv_key_pem` the PKCS#8 PEM private key.
/// As with the file-based constructor, `client_ca_path` must be `None` under `native-tls`.
///
/// # Errors
///
/// - A converted `native-tls` error if the PEM data cannot be parsed into an identity or
///   the acceptor cannot be built.
/// - `Error::UnsupportedFeature` if `client_ca_path` is `Some`.
#[cfg(feature = "acme")]
pub async fn make_server_config_from_pem(
    certs: String,
    priv_key_pem: String,
    client_ca_path: Option<&str>,
) -> Result<TlsIdentityInner, Error> {
    let (chain_pem, key_pem) = (certs.as_bytes(), priv_key_pem.as_bytes());
    make_server_config_from_mem(Identity::from_pkcs8(chain_pem, key_pem)?, client_ca_path)
}

fn make_server_config_from_mem(
    identity: Identity,
    client_ca_path: Option<&str>,
) -> Result<TlsIdentityInner, Error> {
    if client_ca_path.is_some() {
        return Err(Error::UnsupportedFeature(
            "client CA verification",
            "requires rustls",
        ));
    }
    let raw_acceptor = TlsAcceptor::builder(identity).build()?;
    Ok(raw_acceptor.into())
}

/// Create a `TlsConnector` with possibly a client certificate and a custom CA store.
///
/// * `cert_path`, `key_path`: the client certificate chain and PKCS#8 private key (PEM).
///   When `key_path` is `None`, the key is read from `cert_path`. A `key_path` without a
///   `cert_path` is ignored.
/// * `ca_path`: a PEM file whose certificates are added as trusted roots. Every
///   `CERTIFICATE` block in the file is added, so multi-certificate CA bundles work.
/// * `tls_skip_verify`: disables both certificate-chain and hostname verification.
/// * `tls_alpn`: ALPN protocol identifiers to request, if any.
///
/// # Errors
///
/// Returns `Error::ReadCert` if a file cannot be read. Propagates (converted) `native-tls`
/// errors for unparsable certificates or keys, or if the connector fails to build.
pub async fn make_client_config(
    cert_path: Option<&str>,
    key_path: Option<&str>,
    ca_path: Option<&str>,
    tls_skip_verify: bool,
    tls_alpn: Option<&[&str]>,
) -> Result<TlsConnector, Error> {
    let mut tls_config_builder = TlsConnector::builder();
    tls_config_builder
        .danger_accept_invalid_certs(tls_skip_verify)
        .danger_accept_invalid_hostnames(tls_skip_verify);
    if let Some(tls_alpn) = tls_alpn {
        tls_config_builder.request_alpns(tls_alpn);
    }
    if let Some(ca_path) = ca_path {
        let ca = tokio::fs::read(ca_path).await.map_err(Error::ReadCert)?;
        // `Certificate::from_pem` parses a single certificate, so a bundle must be split
        // into its individual PEM blocks or only the first root would be trusted.
        let blocks = pem_certificate_blocks(&ca);
        if blocks.is_empty() {
            // No recognizable `CERTIFICATE` block: hand the whole file to `native-tls`
            // and let it accept or reject the contents.
            tls_config_builder.add_root_certificate(Certificate::from_pem(&ca)?);
        } else {
            for block in blocks {
                tls_config_builder.add_root_certificate(Certificate::from_pem(block)?);
            }
        }
    }
    if let Some(cert_path) = cert_path {
        let identity = read_key_cert(key_path.unwrap_or(cert_path), cert_path).await?;
        tls_config_builder.identity(identity);
    }
    Ok(tls_config_builder.build()?)
}

/// Split PEM data into its `CERTIFICATE` blocks, with the BEGIN/END markers included.
///
/// Text between blocks is skipped. A trailing block without an END marker is dropped.
fn pem_certificate_blocks(pem: &[u8]) -> Vec<&[u8]> {
    const BEGIN: &[u8] = b"-----BEGIN CERTIFICATE-----";
    const END: &[u8] = b"-----END CERTIFICATE-----";
    let mut blocks = Vec::new();
    let mut rest = pem;
    while let Some(start) = find_subslice(rest, BEGIN) {
        let from_begin = &rest[start..];
        let Some(end) = find_subslice(from_begin, END) else {
            break;
        };
        let block_len = end + END.len();
        blocks.push(&from_begin[..block_len]);
        rest = &from_begin[block_len..];
    }
    blocks
}

/// Byte offset of the first occurrence of the non-empty `needle` in `haystack`, if any.
fn find_subslice(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}

async fn read_key_cert(key_path: &str, cert_path: &str) -> Result<Identity, Error> {
    let key = tokio::fs::read(key_path).await.map_err(Error::ReadCert)?;
    let cert = tokio::fs::read(cert_path).await.map_err(Error::ReadCert)?;
    Ok(Identity::from_pkcs8(&cert, &key)?)
}

/// Build the HTTPS-capable hyper connector used for backend requests.
///
/// Always returns `Ok`, hence the `unnecessary_wraps` expectation. The `io::Result`
/// signature presumably mirrors a fallible counterpart in another TLS backend; TODO confirm.
///
/// NOTE(review): `hyper_tls::HttpsConnector::new` is documented to panic if the native TLS
/// context cannot be created. Consider surfacing that failure through the `Result` instead.
#[cfg(feature = "server")]
#[expect(clippy::unnecessary_wraps)]
pub fn make_hyper_connector() -> std::io::Result<HyperConnector> {
    let connector = HyperConnector::new();
    Ok(connector)
}

// `native_tls` on macOS and Windows can read neither Ed25519 nor ECDSA certificates,
// and `rcgen` cannot generate RSA keys, so these tests only run on other platforms.
#[cfg(test)]
#[cfg(not(any(target_os = "macos", target_os = "windows")))]
mod tests {
    use super::*;
    use rcgen::CertificateParams;
    use tempfile::tempdir;

    /// Generate a self-signed ECDSA P-384 certificate for `example.com`.
    ///
    /// Returns `(certificate_pem, private_key_pem)`.
    fn self_signed_pem() -> (String, String) {
        let cert_params = CertificateParams::new(vec!["example.com".into()]).unwrap();
        let keypair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P384_SHA384).unwrap();
        let cert = cert_params.self_signed(&keypair).unwrap();
        (cert.pem(), keypair.serialize_pem())
    }

    #[tokio::test]
    async fn test_read_key_cert() {
        crate::tests::setup_logging();
        let tmpdir = tempdir().unwrap();
        let key_path = tmpdir.path().join("key.pem");
        let cert_path = tmpdir.path().join("cert.pem");
        let (crt, crt_key) = self_signed_pem();
        tokio::fs::write(&cert_path, crt).await.unwrap();
        tokio::fs::write(&key_path, crt_key).await.unwrap();
        read_key_cert(key_path.to_str().unwrap(), cert_path.to_str().unwrap())
            .await
            .unwrap();
    }

    #[tokio::test]
    #[cfg(feature = "acme")]
    async fn test_make_server_config_from_rcgen_pem() {
        crate::tests::setup_logging();
        let (crt, key) = self_signed_pem();
        // `expect` instead of `assert!(result.is_ok())` so a failure shows the actual error.
        make_server_config_from_pem(crt, key, None)
            .await
            .expect("server config from rcgen PEM should build");
    }

    #[test]
    fn test_client_ca_is_unsupported() {
        let (crt, key) = self_signed_pem();
        let identity = Identity::from_pkcs8(crt.as_bytes(), key.as_bytes()).unwrap();
        let result = make_server_config_from_mem(identity, Some("ca.pem"));
        assert!(matches!(result, Err(Error::UnsupportedFeature(..))));
    }

    #[tokio::test]
    async fn test_make_client_config_with_ca_bundle() {
        crate::tests::setup_logging();
        let tmpdir = tempdir().unwrap();
        let ca_path = tmpdir.path().join("ca.pem");
        let (first, _) = self_signed_pem();
        let (second, _) = self_signed_pem();
        tokio::fs::write(&ca_path, format!("{first}\n{second}"))
            .await
            .unwrap();
        make_client_config(None, None, Some(ca_path.to_str().unwrap()), false, None)
            .await
            .expect("client config with a two-certificate CA bundle should build");
    }
}