// ocapn-netlayer 0.1.4
//
// OCapN transport layer interfaces and types
// Documentation
use std::net::ToSocketAddrs;
use std::net::{Ipv4Addr, SocketAddr};
use std::path::PathBuf;
use std::sync::Arc;

use anyhow::{Context, Error, Result};
use async_trait::async_trait;
use tokio::io::split;
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::rustls::pki_types::pem::PemObject;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use tokio_rustls::rustls::RootCertStore;
use tokio_rustls::{rustls, TlsAcceptor};

use crate::connector::Connector;

/// Settings for the TLS netlayer: where to listen and which PEM files to load.
///
/// [`Default`] yields a listener on `0.0.0.0:9445` that reads `cert.pem` / `key.pem`
/// (relative to the working directory) and has no CA file configured.
///
/// NOTE(review): the fields are private and only `Default` constructs this type in the code
/// shown here, so callers outside this module may be unable to customize it — confirm.
#[derive(Debug, Clone)]
pub struct Config {
    /// Local address the TLS listener binds to.
    listen_addr: SocketAddr,
    /// Optional PEM bundle of CA certificates trusted when verifying outgoing peers.
    /// When `None`, no CA certificates are loaded.
    ca_cert_file: Option<PathBuf>,
    /// PEM file holding the certificate chain presented to incoming peers.
    server_cert_file: PathBuf,
    /// PEM file holding the private key that matches `server_cert_file`.
    server_key_file: PathBuf,
}

/// Default TCP port for the TLS netlayer.
///
/// It is the listening port in `Config::default()` and the port dialed when an outgoing
/// locator carries no `port` hint.
const DEFAULT_TLS_PORT: u16 = 9445;

impl Default for Config {
    /// Listens on every IPv4 interface at [`DEFAULT_TLS_PORT`], configures no CA file, and
    /// expects `cert.pem` / `key.pem` relative to the working directory.
    fn default() -> Self {
        // TODO: generate these somewhere in a tempdir?
        let server_cert_file = PathBuf::from("cert.pem");
        let server_key_file = PathBuf::from("key.pem");
        Self {
            listen_addr: SocketAddr::from((Ipv4Addr::UNSPECIFIED, DEFAULT_TLS_PORT)),
            ca_cert_file: None,
            server_cert_file,
            server_key_file,
        }
    }
}

/// TLS netlayer endpoint: dials peers as a TLS client and accepts peers as a TLS server.
///
/// Built from a [`Config`] via `TryFrom`.
pub(crate) struct TlsConnector {
    /// Client-side handshake driver used for outgoing connections.
    connector: tokio_rustls::TlsConnector,

    /// Server-side handshake driver used for accepted connections.
    acceptor: TlsAcceptor,
    /// Socket on which incoming TCP connections are accepted.
    listener: TcpListener,
}

impl TryFrom<Config> for TlsConnector {
    type Error = Error;

    fn try_from(config: Config) -> Result<Self, Self::Error> {
        let mut root_cert_store = RootCertStore::empty();
        if let Some(cafile) = config.ca_cert_file {
            for cert in CertificateDer::pem_file_iter(cafile)? {
                root_cert_store.add(cert?)?;
            }
        }
        let tls_client_config = rustls::ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth(); // i guess this was previously the default?
        let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));

        let server_certs = CertificateDer::pem_file_iter(&config.server_cert_file)?
            .collect::<Result<Vec<_>, _>>()?;
        let server_key = PrivateKeyDer::from_pem_file(&config.server_key_file)?;
        let tls_server_config = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(server_certs, server_key)?;
        let acceptor = TlsAcceptor::from(Arc::new(tls_server_config));
        let listener = std::net::TcpListener::bind(config.listen_addr)?;

        Ok(TlsConnector {
            acceptor,
            connector,
            listener: TcpListener::from_std(listener)?,
        })
    }
}

#[async_trait]
impl Connector for TlsConnector {
    /// Opens a TLS connection to the peer described by `locator`.
    ///
    /// The locator's `designator` is used both as the host to dial and as the TLS server
    /// name to verify. A `port` hint overrides [`DEFAULT_TLS_PORT`].
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the transport is not `"tls"`;
    /// - the `port` hint is not a valid `u16`;
    /// - the designator is not a valid TLS server name;
    /// - no resolved address accepts a TCP connection;
    /// - the TLS handshake, including certificate verification, fails.
    async fn new_outgoing_connection(
        &mut self,
        locator: &crate::locator::Peer,
    ) -> Result<crate::connection::Connection> {
        if !locator.transport.eq("tls") {
            return Err(Error::msg("not a tls address"));
        }
        let port: u16 = match locator.hints.get("port") {
            Some(port_as_str) => port_as_str
                .parse::<u16>()
                .context("invalid `port` hint in tls locator")?,
            None => DEFAULT_TLS_PORT,
        };
        let host = locator.designator.as_str();
        // Validate the server name before touching the network.
        let domain = ServerName::try_from(host)?.to_owned();
        // Tokio's `connect` resolves the host without blocking the runtime thread, which a
        // std `to_socket_addrs()` call would do. It also tries every resolved address in
        // turn rather than only the first.
        let tcp_stream = TcpStream::connect((host, port))
            .await
            .with_context(|| format!("failed to connect to {host}:{port}"))?;
        let tls_stream = self.connector.connect(domain, tcp_stream).await?;
        Ok(split(tls_stream).into())
    }

    /// Waits for the next inbound TCP connection and performs the server-side TLS handshake.
    ///
    /// Returns `Some` on success. Errors from `accept` or from the handshake are propagated.
    ///
    /// NOTE(review): the handshake is awaited before this returns. A client that connects
    /// but never completes the handshake therefore stalls further accepts; consider a
    /// timeout.
    async fn accept_incoming_connection(
        &mut self,
    ) -> Result<Option<crate::connection::Connection>> {
        let (tcp_stream, _peer_addr) = self.listener.accept().await?;
        // `TlsAcceptor::accept` only needs `&self`, so the acceptor need not be cloned.
        let tls_stream = self.acceptor.accept(tcp_stream).await?;
        Ok(Some(split(tls_stream).into()))
    }
}