lers 0.4.0

An async, user-friendly Let's Encrypt/ACMEv2 library written in Rust
Documentation
use super::stream::{AllowStd, TlsAcceptor};
use futures::join;
use native_tls::Certificate;
use openssl::{
    pkcs12::Pkcs12,
    ssl::{SslAcceptor, SslMethod, SslStream},
    x509::X509VerifyResult,
};
use std::{
    fs,
    io::{Error, ErrorKind},
    iter,
};
use tokio::{
    io::{AsyncReadExt, AsyncWrite, AsyncWriteExt},
    net::{TcpListener, TcpStream},
};
use tokio_native_tls::TlsConnector;

/// Streams `AMOUNT` bytes from a TLS client to a TLS server over a loopback
/// TCP connection and checks that the server receives every byte.
#[tokio::test]
async fn client_to_server() {
    // Port 0 lets the OS pick a free port, so parallel tests do not collide.
    let srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = srv.local_addr().unwrap();

    let (server_tls, client_tls) = context();

    // Server side: accept one connection, finish the TLS handshake and read
    // until end of stream.
    let server = async move {
        let (socket, _) = srv.accept().await.unwrap();
        let mut socket = server_tls.accept(socket).await.unwrap();

        // The accepted stream exposes each wrapped layer through `get_ref`:
        // the OpenSSL stream, then the `AllowStd` adapter, then the TCP socket.
        let openssl_stream: &SslStream<_> = socket.get_ref();
        assert_eq!(openssl_stream.ssl().verify_result(), X509VerifyResult::OK);
        let allow_std_stream: &AllowStd<_> = openssl_stream.get_ref();
        let _tokio_tcp_stream: &TcpStream = allow_std_stream.get_ref();

        let mut data = Vec::new();
        socket.read_to_end(&mut data).await.unwrap();
        data
    };

    // Client side: connect, handshake and push the payload through `copy`.
    let client = async move {
        let socket = TcpStream::connect(&addr).await.unwrap();
        let socket = client_tls.connect("foobar.com", socket).await.unwrap();
        copy(socket).await
    };

    let (data, copied) = join!(server, client);
    // Check the writer's result instead of discarding it, so an I/O failure
    // is reported as such rather than as a confusing payload mismatch.
    assert_eq!(copied.unwrap(), AMOUNT);
    assert_eq!(data, vec![9; AMOUNT]);
}

/// Streams `AMOUNT` bytes from a TLS server to a TLS client over a loopback
/// TCP connection and checks that the client receives every byte.
#[tokio::test]
async fn server_to_client() {
    // Port 0 lets the OS pick a free port, so parallel tests do not collide.
    let srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = srv.local_addr().unwrap();

    let (server_tls, client_tls) = context();

    // Server side: accept one connection, handshake and push the payload
    // through `copy`.
    let server = async move {
        let (socket, _) = srv.accept().await.unwrap();
        let socket = server_tls.accept(socket).await.unwrap();
        copy(socket).await
    };

    // Client side: connect, handshake and read until end of stream.
    let client = async move {
        let socket = TcpStream::connect(&addr).await.unwrap();
        let mut socket = client_tls.connect("foobar.com", socket).await.unwrap();

        let mut data = Vec::new();
        socket.read_to_end(&mut data).await.unwrap();
        data
    };

    let (copied, data) = join!(server, client);
    // Check the writer's result instead of discarding it, so an I/O failure
    // is reported as such rather than as a confusing payload mismatch.
    assert_eq!(copied.unwrap(), AMOUNT);
    assert_eq!(data, vec![9; AMOUNT]);
}

#[tokio::test]
async fn one_byte_at_a_time() {
    const AMOUNT: usize = 1024;

    let srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = srv.local_addr().unwrap();

    let (server_tls, client_tls) = context();

    let server = async move {
        let (socket, _) = srv.accept().await.unwrap();
        let mut socket = server_tls.accept(socket).await.unwrap();

        let mut sent = 0;
        for b in iter::repeat(9).take(AMOUNT) {
            let data = [b as u8];
            socket.write_all(&data).await.unwrap();
            sent += 1;
        }
        sent
    };

    let client = async move {
        let socket = TcpStream::connect(&addr).await.unwrap();
        let mut socket = client_tls.connect("foobar.com", socket).await.unwrap();

        let mut data = Vec::new();
        loop {
            let mut buf = [0; 1];
            match socket.read_exact(&mut buf).await {
                Ok(_) => data.extend_from_slice(&buf),
                Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => break,
                Err(err) => panic!("{}", err),
            }
        }
        data
    };

    let (amount, data) = join!(server, client);
    assert_eq!(amount, AMOUNT);
    assert_eq!(data, vec![9; AMOUNT]);
}

/// Builds the server acceptor and client connector shared by the tests.
///
/// The acceptor is loaded from the PKCS#12 identity under
/// `testdata/tls-alpn-01/`, and the connector is configured with the root
/// certificate from the same directory added as a root certificate.
///
/// # Panics
/// Panics if either test fixture cannot be read or parsed, or if any TLS
/// builder call fails.
fn context() -> (TlsAcceptor, TlsConnector) {
    // --- server side -------------------------------------------------------
    let identity_der = fs::read("testdata/tls-alpn-01/identity.p12").unwrap();
    let identity = Pkcs12::from_der(&identity_der)
        .unwrap()
        .parse2("mypass")
        .unwrap();

    let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
    let key = identity.pkey.unwrap();
    builder.set_private_key(&key).unwrap();
    let leaf = identity.cert.unwrap();
    builder.set_certificate(&leaf).unwrap();

    // Any additional certificates from the bundle are appended in reverse of
    // the order in which the parsed stack yields them.
    if let Some(chain) = identity.ca {
        for extra in chain.into_iter().rev() {
            builder.add_extra_chain_cert(extra).unwrap();
        }
    }

    // Clear the protocol-version bounds set by the `mozilla_intermediate` preset.
    builder.set_min_proto_version(None).unwrap();
    builder.set_max_proto_version(None).unwrap();
    let server = builder.build();

    // --- client side -------------------------------------------------------
    let root_der = fs::read("testdata/tls-alpn-01/root-ca.der").unwrap();
    let root = Certificate::from_der(&root_der).unwrap();
    let mut client = native_tls::TlsConnector::builder();
    client.add_root_certificate(root);
    let client = client.build().unwrap();

    (server.into(), client.into())
}

/// Payload size in bytes (128 KiB) written by `copy` and expected by the
/// tests that compare against `vec![9; AMOUNT]`.
const AMOUNT: usize = 128 * 1024;

async fn copy<W: AsyncWrite + Unpin>(mut w: W) -> Result<usize, Error> {
    let mut data = vec![9; AMOUNT];
    let mut copied = 0;

    while !data.is_empty() {
        let written = w.write(&data).await?;
        if written <= data.len() {
            copied += written;
            data.resize(data.len() - written, 0);
        } else {
            w.write_all(&data).await?;
            copied += data.len();
            break;
        }

        println!("remaining: {}", data.len());
    }

    Ok(copied)
}