portfu_core 1.3.3

Portfu Core Types and Definitions Library
Documentation
use crate::server::ServerConfig;
use log::error;
use rand::Rng;
use rsa::pkcs1::DecodeRsaPrivateKey;
use rsa::pkcs1v15::SigningKey;
use rsa::pkcs8::{DecodePrivateKey, EncodePrivateKey, EncodePublicKey};
use rustls::client::danger::HandshakeSignatureValid;
use rustls::crypto::ring::default_provider;
use rustls::crypto::ring::sign::RsaSigningKey;
use rustls::pki_types::{CertificateDer, DnsName, PrivateKeyDer, ServerName, UnixTime};
use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
use rustls::server::{ClientHello, ParsedCertificate, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use rustls::{DigitallySignedStruct, DistinguishedName, RootCertStore, SignatureScheme};
use rustls_pemfile::{certs, read_one, Item};
use sha2::Sha256;
use std::collections::HashMap;
use std::env;
use std::fmt::Debug;
use std::io::{BufReader, Error, ErrorKind};
use std::ops::Sub;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use x509_cert::builder::{Builder, CertificateBuilder, Profile};
use x509_cert::der::asn1::{Ia5String, UtcTime};
use x509_cert::der::pem::LineEnding;
use x509_cert::der::{DateTime, DecodePem, EncodePem};
use x509_cert::ext::pkix::name::GeneralName;
use x509_cert::ext::pkix::SubjectAltName;
use x509_cert::name::Name;
use x509_cert::serial_number::SerialNumber;
use x509_cert::spki::SubjectPublicKeyInfo;
use x509_cert::time::{Time, Validity};
use x509_cert::Certificate;

pub fn load_ssl_certs(config: &ServerConfig) -> Result<Arc<rustls::ServerConfig>, Error> {
    default_provider().install_default().unwrap_or_default();
    let (certs, key, root_certs) = if let Some(ssl_info) = &config.ssl_config {
        (
            load_certs(ssl_info.certs.as_bytes())?,
            load_private_key(ssl_info.key.as_bytes())?,
            load_certs(ssl_info.root_certs.as_bytes())?,
        )
    } else if let (Some(crt), Some(key)) = (
        env::var("PRIVATE_CA_CRT").ok(),
        env::var("PRIVATE_CA_KEY").ok(),
    ) {
        let (cert_bytes, key_bytes) = generate_ca_signed_cert(crt.as_bytes(), key.as_bytes())?;
        (
            load_certs(&cert_bytes)?,
            load_private_key(&key_bytes)?,
            load_certs(crt.as_bytes())?,
        )
    } else if let (Some(certs), Some(key), Some(root_certs)) = (
        env::var("SSL_CERTS").ok(),
        env::var("SSL_PRIVATE_KEY").ok(),
        env::var("SSL_ROOT_CERTS").ok(),
    ) {
        (
            load_certs(certs.as_bytes())?,
            load_private_key(key.as_bytes())?,
            load_certs(root_certs.as_bytes())?,
        )
    } else {
        return Err(Error::new(ErrorKind::InvalidInput, "Invalid SSL Config"));
    };
    let mut root_cert_store = RootCertStore::empty();
    for cert in root_certs {
        root_cert_store.add(cert).map_err(|e| {
            Error::new(
                ErrorKind::InvalidInput,
                format!("Invalid Root Cert for Server: {:?}", e),
            )
        })?;
    }
    let mut resolver = ResolvesServerCertUsingSniWithDefault::new();
    let name = config
        .ssl_config
        .as_ref()
        .map(|c| c.domain.as_str())
        .unwrap_or("localhost");
    let cer_key = CertifiedKey::new(
        certs,
        Arc::new(RsaSigningKey::new(&key).map_err(|e| {
            Error::new(
                ErrorKind::InvalidInput,
                format!("Private Key is not Valid SigningKey: {:?}", e),
            )
        })?),
    );
    resolver.add(name, cer_key).map_err(|e| {
        Error::new(
            ErrorKind::InvalidInput,
            format!("Failed to add SSL Certs to Resolver: {:?}", e),
        )
    })?;
    match &config.client_ssl_config {
        None => {
            let resolver = Arc::new(resolver);
            Ok(Arc::new(
                rustls::ServerConfig::builder()
                    .with_no_client_auth()
                    .with_cert_resolver(resolver),
            ))
        }
        Some(client_ssl) => {
            let (certs, key, root_certs) = (
                load_certs(client_ssl.certs.as_bytes())?,
                load_private_key(client_ssl.key.as_bytes())?,
                load_certs(client_ssl.root_certs.as_bytes())?,
            );
            for cert in root_certs {
                root_cert_store.add(cert).map_err(|e| {
                    Error::new(
                        ErrorKind::InvalidInput,
                        format!("Invalid Root Cert for Server: {:?}", e),
                    )
                })?;
            }
            let cer_key = CertifiedKey::new(
                certs,
                Arc::new(RsaSigningKey::new(&key).map_err(|e| {
                    Error::new(
                        ErrorKind::InvalidInput,
                        format!("Private Key is not Valid SigningKey: {:?}", e),
                    )
                })?),
            );
            resolver
                .add(client_ssl.domain.as_str(), cer_key)
                .map_err(|e| {
                    Error::new(
                        ErrorKind::InvalidInput,
                        format!("Failed to add SSL Certs to Resolver: {:?}", e),
                    )
                })?;
            let resolver = Arc::new(resolver);
            Ok(Arc::new(
                rustls::ServerConfig::builder()
                    .with_client_cert_verifier(AllowAny::new())
                    .with_cert_resolver(resolver),
            ))
        }
    }
}
pub fn load_certs(bytes: &[u8]) -> Result<Vec<CertificateDer<'static>>, Error> {
    let mut reader = BufReader::new(bytes);
    let certs = certs(&mut reader);
    Ok(certs.into_iter().flatten().collect())
}

pub fn load_private_key(bytes: &[u8]) -> Result<PrivateKeyDer<'static>, Error> {
    let mut reader = BufReader::new(bytes);
    for item in std::iter::from_fn(|| read_one(&mut reader).transpose()) {
        if let Some(item) = handle_item(item) {
            return Ok(item);
        }
    }
    Err(Error::new(ErrorKind::NotFound, "Private Key Not Found"))
}

pub fn generate_ca_signed_cert(
    cert_data: &[u8],
    key_data: &[u8],
) -> Result<(Vec<u8>, Vec<u8>), Error> {
    let root_cert = Certificate::from_pem(cert_data).map_err(|e| Error::other(format!("{e:?}")))?;
    let root_key = rsa::RsaPrivateKey::from_pkcs1_pem(&String::from_utf8_lossy(key_data))
        .or_else(|_| rsa::RsaPrivateKey::from_pkcs8_pem(&String::from_utf8_lossy(key_data)))
        .map_err(|e| Error::other(format!("Failed to load Root Key: {e:?}")))?;
    let mut rng = rand::thread_rng();
    let cert_key =
        rsa::RsaPrivateKey::new(&mut rng, 2048).map_err(|e| Error::other(format!("{e:?}")))?;
    let pub_key = cert_key.to_public_key();
    let signing_key: SigningKey<Sha256> = SigningKey::new(root_key);
    let subject_pub_key = SubjectPublicKeyInfo::from_pem(
        pub_key
            .to_public_key_pem(LineEnding::default())
            .map_err(|e| Error::other(format!("{e:?}")))?
            .as_bytes(),
    )
    .map_err(|e| Error::other(format!("{e:?}")))?;
    let mut cert = CertificateBuilder::new(
        Profile::Leaf {
            issuer: root_cert.tbs_certificate.issuer,
            enable_key_agreement: false,
            enable_key_encipherment: false,
        },
        SerialNumber::from(rng.gen::<u32>()),
        Validity {
            not_before: Time::UtcTime(
                UtcTime::from_system_time(SystemTime::now().sub(Duration::from_secs(60 * 60 * 24)))
                    .map_err(|e| Error::other(format!("{e:?}")))?,
            ),
            not_after: Time::UtcTime(
                UtcTime::from_date_time(
                    DateTime::new(2049, 8, 2, 0, 0, 0)
                        .map_err(|e| Error::other(format!("{e:?}")))?,
                )
                .map_err(|e| Error::other(format!("{e:?}")))?,
            ),
        },
        Name::from_str("CN=Chia,O=Chia,OU=Organic Farming Division")
            .map_err(|e| Error::other(format!("{e:?}")))?,
        subject_pub_key,
        &signing_key,
    )
    .map_err(|e| Error::other(format!("{e:?}")))?;
    cert.add_extension(&SubjectAltName(vec![GeneralName::DnsName(
        Ia5String::new("chia.net").map_err(|e| Error::other(format!("{e:?}")))?,
    )]))
    .map_err(|e| Error::other(format!("{e:?}")))?;
    let cert = cert.build().map_err(|e| Error::other(format!("{e:?}")))?;
    Ok((
        cert.to_pem(LineEnding::default())
            .map_err(|e| Error::other(format!("{e:?}")))?
            .as_bytes()
            .to_vec(),
        cert_key
            .to_pkcs8_pem(LineEnding::default())
            .map_err(|e| Error::other(format!("{e:?}")))?
            .as_bytes()
            .to_vec(),
    ))
}

fn handle_item(item: Result<Item, Error>) -> Option<PrivateKeyDer<'static>> {
    match item {
        Ok(Item::Pkcs8Key(key)) => {
            return Some(PrivateKeyDer::Pkcs8(key));
        }
        Ok(Item::Pkcs1Key(key)) => {
            return Some(PrivateKeyDer::Pkcs1(key));
        }
        Ok(Item::Sec1Key(key)) => {
            return Some(PrivateKeyDer::Sec1(key));
        }
        Ok(Item::X509Certificate(_)) => error!("Found Certificate, not Private Key"),
        Ok(Item::Crl(_)) => error!("Found Crl, not Private Key"),
        Ok(Item::Csr(_)) => error!("Found Csr, not Private Key"),
        _ => {
            error!("Unknown Item while loading private key")
        }
    }
    None
}

#[derive(Debug)]
pub struct AllowAny {}
impl AllowAny {
    #[must_use]
    pub fn new() -> Arc<Self> {
        Arc::new(Self {})
    }
}
impl ClientCertVerifier for AllowAny {
    fn client_auth_mandatory(&self) -> bool {
        false
    }

    fn root_hint_subjects(&self) -> &[DistinguishedName] {
        &[]
    }

    fn verify_client_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _now: UnixTime,
    ) -> Result<ClientCertVerified, rustls::Error> {
        Ok(ClientCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        vec![
            SignatureScheme::ECDSA_NISTP521_SHA512,
            SignatureScheme::ECDSA_NISTP384_SHA384,
            SignatureScheme::ECDSA_NISTP256_SHA256,
            SignatureScheme::RSA_PKCS1_SHA256,
            SignatureScheme::RSA_PKCS1_SHA384,
            SignatureScheme::RSA_PKCS1_SHA512,
            SignatureScheme::RSA_PSS_SHA512,
            SignatureScheme::RSA_PSS_SHA384,
            SignatureScheme::RSA_PSS_SHA256,
            SignatureScheme::ED25519,
        ]
    }
}

#[derive(Debug)]
pub struct ResolvesServerCertUsingSniWithDefault {
    by_name: HashMap<String, Arc<CertifiedKey>>,
}
impl ResolvesServerCertUsingSniWithDefault {
    pub fn new() -> Self {
        Self {
            by_name: HashMap::new(),
        }
    }
    pub fn add(&mut self, name: &str, ck: CertifiedKey) -> Result<(), Error> {
        let server_name = {
            let checked_name = DnsName::try_from(name)
                .map_err(|_| Error::new(ErrorKind::InvalidInput, "Bad DNS name"))
                .map(|name| name.to_lowercase_owned())?;
            ServerName::DnsName(checked_name)
        };
        ck.end_entity_cert()
            .and_then(ParsedCertificate::try_from)
            .map_err(|_| Error::new(ErrorKind::InvalidInput, "Bad Entity Cert"))?;
        if let ServerName::DnsName(name) = server_name {
            self.by_name.insert(name.as_ref().to_string(), Arc::new(ck));
        }
        Ok(())
    }
}

impl ResolvesServerCert for ResolvesServerCertUsingSniWithDefault {
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        if let Some(name) = client_hello.server_name() {
            self.by_name.get(name).cloned()
        } else {
            self.by_name.values().next().cloned()
        }
    }
}