hyperdriver 0.12.3

The missing middle for Hyper - Servers and Clients with ergonomic APIs
Documentation
//! Test TLS support in braided streams.

use std::{net::Ipv4Addr, sync::Arc};

use rustls::ServerConfig;

fn tls_config() -> rustls::ServerConfig {
    let (_, cert) =
        pem_rfc7468::decode_vec(include_bytes!("../minica/example.com/cert.pem")).unwrap();
    let (label, key) =
        pem_rfc7468::decode_vec(include_bytes!("../minica/example.com/key.pem")).unwrap();

    let cert = rustls::pki_types::CertificateDer::from(cert);
    let key = match label {
        "PRIVATE KEY" => rustls::pki_types::PrivateKeyDer::Pkcs8(key.into()),
        "RSA PRIVATE KEY" => rustls::pki_types::PrivateKeyDer::Pkcs1(key.into()),
        "EC PRIVATE KEY" => rustls::pki_types::PrivateKeyDer::Sec1(key.into()),
        _ => panic!("unknown key type"),
    };

    ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![cert], key)
        .unwrap()
}

fn tls_root_store() -> rustls::RootCertStore {
    let mut root_store = rustls::RootCertStore::empty();
    let (_, cert) = pem_rfc7468::decode_vec(include_bytes!("../minica/minica.pem")).unwrap();
    root_store
        .add(rustls::pki_types::CertificateDer::from(cert))
        .unwrap();
    root_store
}

fn tls_install_default() {
    let _ = rustls::crypto::ring::default_provider().install_default();
}

#[tokio::test]
async fn braided_tls() {
    use futures_util::StreamExt;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    tls_install_default();

    let incoming = tokio::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0))
        .await
        .unwrap();
    let addr = incoming.local_addr().unwrap();

    let server =
        hyperdriver::server::conn::Acceptor::from(incoming).with_tls(Arc::new(tls_config()));

    tokio::spawn(async move {
        let mut incoming = server.fuse();
        while let Some(Ok(mut stream)) = incoming.next().await {
            let mut buf = [0u8; 1024];
            let n = stream.read(&mut buf).await.unwrap();
            stream.write_all(&buf[..n]).await.unwrap();
        }
    });

    let mut conn = hyperdriver::client::conn::Stream::connect(addr)
        .await
        .unwrap()
        .tls(
            "example.com",
            rustls::ClientConfig::builder()
                .with_root_certificates(tls_root_store())
                .with_no_client_auth()
                .into(),
        );

    let mut buf = [0u8; 1024];
    conn.write_all(b"hello world").await.unwrap();
    let n = conn.read(&mut buf).await.unwrap();
    assert_eq!(&buf[..n], b"hello world");
}