endpoint-libs 1.7.28

Common dependencies to be used with Pathscale projects, projects that use [endpoint-gen](https://github.com/pathscale/endpoint-gen), and projects that use honey_id-types.
Documentation
use std::fs::File;
use std::net::SocketAddr;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;

use eyre::{Context, Result, ensure};
use futures::FutureExt;
use futures::future::BoxFuture;
use rustls::pki_types::CertificateDer;
use rustls::pki_types::PrivateKeyDer;
use rustls::pki_types::pem::PemObject;
use rustls_pemfile::certs;
use tokio_rustls::TlsAcceptor;
use tokio_rustls::server::TlsStream;

use super::ConnectionListener;

/// A listener that layers a server-side TLS handshake on top of another
/// listener of type `T`.
///
/// Constructed with [`TlsListener::bind`]; when `T` implements
/// `ConnectionListener`, this type implements it as well and yields
/// `TlsStream`-wrapped channels.
pub struct TlsListener<T> {
    /// The wrapped (inner) listener that connections are accepted from.
    tcp: T,
    /// TLS acceptor used to run the server-side handshake on each channel.
    acceptor: TlsAcceptor,
}

impl<T: ConnectionListener> TlsListener<T> {
    /// Builds a TLS listener around `under`.
    ///
    /// * `under` - the inner listener that actually accepts connections.
    /// * `pub_certs` - paths of PEM files holding the certificate chain; the
    ///   certificates of all files are concatenated in the given order.
    /// * `priv_cert` - path of the PEM file holding the private key.
    ///
    /// The offered TLS protocol versions and ALPN protocols are selected at
    /// compile time through cargo features (see the comments in the body).
    ///
    /// # Errors
    ///
    /// Fails when a certificate file cannot be loaded, when no certificate at
    /// all was found, when the private key cannot be loaded, or when rustls
    /// rejects the certificate/key pair.
    pub async fn bind(under: T, pub_certs: Vec<PathBuf>, priv_cert: PathBuf) -> Result<Self> {
        let cert_chain = load_certs(&pub_certs)?;
        ensure!(
            !cert_chain.is_empty(),
            "No certificates found in file: {:?}",
            pub_certs
        );
        let private_key = load_private_key(&priv_cert)?;

        // Protocol versions. The order of these bindings matters: the
        // `ws-wtx` binding comes last and shadows whichever one was selected
        // by `ws-tls12`, so with `ws-wtx` enabled only TLS 1.2 is offered.
        #[cfg(feature = "ws-tls12")]
        let protocol_versions: &[&rustls::SupportedProtocolVersion] =
            &[&rustls::version::TLS13, &rustls::version::TLS12];
        #[cfg(not(feature = "ws-tls12"))]
        let protocol_versions: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS13];
        #[cfg(feature = "ws-wtx")]
        let protocol_versions: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS12];

        // ALPN protocols advertised to clients, again picked by feature set.
        // NOTE(review): no binding exists when neither `ws` nor `ws-wtx` is
        // enabled - presumably this module is only compiled with one of them;
        // confirm against the crate's feature gates.
        #[cfg(any(
            all(feature = "ws", feature = "ws-http1"),
            all(feature = "ws-wtx", feature = "ws-wtx-http2"),
        ))]
        let alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
        #[cfg(all(feature = "ws-wtx", not(feature = "ws-wtx-http2")))]
        let alpn_protocols = vec![b"http/1.1".to_vec()];
        #[cfg(all(feature = "ws", not(feature = "ws-http1")))]
        let alpn_protocols = vec![b"h2".to_vec()];

        // Server configuration without client-certificate authentication.
        let mut server_config =
            rustls::ServerConfig::builder_with_protocol_versions(protocol_versions)
                .with_no_client_auth()
                .with_single_cert(cert_chain, private_key)?;
        server_config.alpn_protocols = alpn_protocols;

        Ok(Self {
            tcp: under,
            acceptor: TlsAcceptor::from(Arc::new(server_config)),
        })
    }
}

impl<T: ConnectionListener + 'static> ConnectionListener for TlsListener<T> {
    /// The raw accepted channel is whatever the inner listener produces.
    type Channel1 = T::Channel1;
    /// The ready-to-use channel is the inner listener's channel wrapped in TLS.
    type Channel2 = TlsStream<T::Channel2>;

    /// Accepts the next connection by delegating to the inner listener.
    /// No TLS work is done at this stage.
    fn accept(&self) -> BoxFuture<'_, Result<(Self::Channel1, SocketAddr)>> {
        self.tcp.accept()
    }

    /// Upgrades an accepted channel: the inner listener's own handshake runs
    /// first, then the TLS server handshake is performed on its result.
    ///
    /// An error from either step is returned to the caller.
    fn handshake(&self, channel: Self::Channel1) -> BoxFuture<'_, Result<Self::Channel2>> {
        async {
            let inner = self.tcp.handshake(channel).await?;
            Ok(self.acceptor.accept(inner).await?)
        }
        .boxed()
    }
}

/// Reads every PEM file yielded by `path` and returns all certificates found,
/// concatenated in iteration order.
///
/// An empty result is not treated as an error here; callers that require at
/// least one certificate must check for that themselves.
///
/// # Errors
///
/// Returns an error, annotated with the offending file path, when a file
/// cannot be opened or when a certificate entry in it cannot be read or
/// decoded. Such entries are reported rather than skipped, so a damaged
/// chain is never loaded partially.
fn load_certs<'a, T: AsRef<Path>>(
    path: impl IntoIterator<Item = T>,
) -> Result<Vec<CertificateDer<'a>>> {
    let mut r_certs = Vec::new();
    for p in path {
        let p = p.as_ref();
        let f = File::open(p).with_context(|| format!("Failed to open {}", p.display()))?;
        let mut reader = std::io::BufReader::new(f);

        for cert in certs(&mut reader) {
            // Surface read/decode failures instead of silently dropping the entry.
            let cert = cert
                .with_context(|| format!("Failed to parse certificate in {}", p.display()))?;
            r_certs.push(cert);
        }
    }
    Ok(r_certs)
}

fn load_private_key(path: &PathBuf) -> Result<PrivateKeyDer<'static>> {
    let private_key =
        PrivateKeyDer::from_pem_file(path).wrap_err("Error loading private key from file.")?;

    Ok(private_key)
}