t2_bus 0.1.0

An inter- or intra-process message bus supporting publish/subscribe and request/response.
Documentation
#![cfg(feature = "tls")]

use std::convert::TryFrom;

use std::fs::File;
use std::io::BufReader;


use std::path::Path;
use std::sync::Arc;

use rustls_pemfile::certs;
use rustls_pemfile::private_key;
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio_rustls::rustls::pki_types::CertificateDer;
use tokio_rustls::rustls::pki_types::PrivateKeyDer;
use tokio_rustls::rustls::pki_types::ServerName;
use tokio_rustls::rustls::server::WebPkiClientVerifier;
use tokio_rustls::rustls::ClientConfig;
use tokio_rustls::rustls::RootCertStore;
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::TlsAcceptor;
use tokio_rustls::TlsConnector;
use tokio_rustls::client::TlsStream;
use tokio_util::codec::Framed;


use crate::server::listen::listen_and_serve;
use crate::stopper::MultiStopper;
use crate::{protocol::{Msg, ProtocolClient, ProtocolServer}, server::listen::Listener, err::BusResult, transport::CborCodec};

use super::BusError;
use super::Transport;

/// Starts a TLS-secured bus server bound to `addr`.
///
/// # Arguments
/// * `addr` - socket address(es) to listen on.
/// * `certs_pem_file` - PEM file holding the server certificate chain.
/// * `certs_key_file` - PEM file holding the server private key.
/// * `root_cert_pem_file` - PEM file holding the root certificates used to
///   verify client certificates.
///
/// # Errors
/// Returns an error if the [`TlsListener`] cannot be constructed or if
/// `listen_and_serve` fails.
pub async fn serve(
    addr: impl ToSocketAddrs,
    certs_pem_file: &Path,
    certs_key_file: &Path,
    root_cert_pem_file: &Path,
) -> BusResult<MultiStopper> {
    let tls_listener =
        TlsListener::new(addr, certs_pem_file, certs_key_file, root_cert_pem_file).await?;

    listen_and_serve(tls_listener)
}

pub async fn connect (
    host: &str,
    port: u16,
    ca_file: &Path,
    certs_pem_file: &Path, 
    key_file: &Path
) -> BusResult<Framed<TlsStream<TcpStream>, CborCodec<Msg<ProtocolClient>, Msg<ProtocolServer>>>> {

    let mut root_cert_store = RootCertStore::empty();
    let mut pem = BufReader::new(File::open(ca_file)?);
    for cert in rustls_pemfile::certs(&mut pem) {
        root_cert_store.add(cert?).unwrap();
    }

    let cert_chain = load_certs(certs_pem_file)?;
    let key = load_key(key_file)?;

    let config = ClientConfig::builder()
        .with_root_certificates(root_cert_store)
        .with_client_auth_cert(cert_chain, key)
        .map_err(|e| BusError::TlsConfigError(e.to_string()))?;
    
    let connector = TlsConnector::from(Arc::new(config));

    let socket = TcpStream::connect(&format!("{host}:{port}")).await?;

    let domain = ServerName::try_from(host)
        .map_err(|_| BusError::TlsConfigError("invalid hostname".into()))?
        .to_owned();

    let tls_socket = connector.connect(domain, socket).await?;

    let transport = tokio_util::codec::Framed::new(tls_socket, CborCodec::new());
    Ok(transport)
}

/// Listener that accepts TCP connections and upgrades each one to TLS.
pub(crate) struct TlsListener {
    /// Bound TCP socket that incoming connections are accepted from.
    listener: tokio::net::TcpListener,
    /// Acceptor that performs the server side of the TLS handshake.
    tls_acceptor: TlsAcceptor,
}

impl TlsListener {
    /// Binds a TCP listener on `addr` and prepares a TLS acceptor that verifies
    /// client certificates against the given roots.
    ///
    /// # Arguments
    /// * `addr` - socket address(es) to bind.
    /// * `certs_pem_file` - PEM file holding the server certificate chain.
    /// * `key_file` - PEM file holding the server private key.
    /// * `root_cert_pem_file` - PEM file holding the root certificates used by
    ///   the client certificate verifier.
    ///
    /// # Errors
    /// Returns an error if binding fails, if any file cannot be loaded, or
    /// (as `BusError::TlsConfigError`) if the verifier or server
    /// configuration cannot be built.
    pub(crate) async fn new(
        addr: impl ToSocketAddrs,
        certs_pem_file: &Path,
        key_file: &Path,
        root_cert_pem_file: &Path,
    ) -> BusResult<Self> {
        // Bind first, then load the TLS material, in the same order as before
        // so the first error reported is unchanged.
        let listener = tokio::net::TcpListener::bind(addr).await?;

        let server_chain = load_certs(certs_pem_file)?;
        let server_key = load_key(key_file)?;
        let trusted_roots = Arc::new(load_root_cert_store(root_cert_pem_file)?);

        let client_verifier = WebPkiClientVerifier::builder(trusted_roots)
            .build()
            .map_err(|e| BusError::TlsConfigError(e.to_string()))?;

        let tls_config = ServerConfig::builder()
            .with_client_cert_verifier(client_verifier)
            .with_single_cert(server_chain, server_key)
            .map_err(|e| BusError::TlsConfigError(e.to_string()))?;

        Ok(Self {
            listener,
            tls_acceptor: TlsAcceptor::from(Arc::new(tls_config)),
        })
    }
}

impl Listener for TlsListener{
    async fn accept(&mut self) -> BusResult<impl Transport<ProtocolServer, ProtocolClient>> {
        let (socket, _) = self.listener.accept().await?;
        let tls_acceptor = self.tls_acceptor.clone();
        let tls_socket = tls_acceptor.accept(socket).await?;
        let transport = tokio_util::codec::Framed::new(tls_socket, CborCodec::new());
        
        Ok(transport)
    }
}

/// Reads every certificate from the PEM file at `path`.
///
/// # Errors
/// Returns an error if the file cannot be opened, or
/// `BusError::TlsConfigError` if reading a certificate from the file fails.
fn load_certs(path: &Path) -> BusResult<Vec<CertificateDer<'static>>> {
    let mut reader = BufReader::new(File::open(path)?);

    certs(&mut reader)
        .map(|item| item.map_err(|e| BusError::TlsConfigError(e.to_string())))
        .collect::<Result<Vec<_>, _>>()
}

fn load_root_cert_store(path: &Path) -> BusResult<RootCertStore> {
    let root_certs: BusResult<Vec<CertificateDer<'static>>> = certs(&mut BufReader::new(File::open(path)?)).map(|r| r.map_err(|e| BusError::TlsConfigError(e.to_string()))).collect();
    let root_certs = root_certs?;
    let mut cert_store = RootCertStore::empty();
    for cert in root_certs.into_iter() {
        cert_store.add(cert).map_err(|e| BusError::TlsConfigError(e.to_string()))?;
    }

    Ok(cert_store)
}

/// Reads the first private key from the PEM file at `path`.
///
/// # Errors
/// * An error converted from the I/O error if the file cannot be opened.
/// * `BusError::TlsConfigError` if the file cannot be read or parsed as PEM,
///   or if it contains no private key.
fn load_key(path: &Path) -> BusResult<PrivateKeyDer<'static>> {
    let mut reader = BufReader::new(File::open(path)?);

    // A malformed or unreadable key file is a configuration problem supplied
    // by the user, so report it as an error instead of panicking.
    private_key(&mut reader)
        .map_err(|e| BusError::TlsConfigError(e.to_string()))?
        .ok_or_else(|| BusError::TlsConfigError("no private key found".to_string()))
}