wild-doc-webserver 0.0.7

This is still in development.
use core::task::{Context, Poll};
use std::{sync::Arc, pin::Pin};

use futures_util::{ready, Future};
use hyper::server::{
    conn::{AddrStream, AddrIncoming}
    ,accept::Accept
};
use rustls::{ServerConfig, server::ResolvesServerCertUsingSni, sign::RsaSigningKey};
use tokio::io::{AsyncRead, ReadBuf, AsyncWrite};

/// Lifecycle of a [`TlsStream`]: the server-side handshake is only polled from
/// `poll_read` / `poll_write`, after which all I/O goes to the established session.
enum State {
    /// Handshake future created but not yet completed.
    Handshaking(tokio_rustls::Accept<AddrStream>),
    /// Handshake finished; I/O is delegated to the rustls server stream.
    Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}

/// A server-side TLS connection produced by [`TlsAcceptor`].
pub struct TlsStream {
    state: State,
}

impl TlsStream {
    /// Wraps `stream` in a pending TLS handshake using `config`.
    ///
    /// The handshake future is only created here; it is polled by the first
    /// read or write on the returned stream.
    fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
        let acceptor = tokio_rustls::TlsAcceptor::from(config);
        let handshake = acceptor.accept(stream);
        Self {
            state: State::Handshaking(handshake),
        }
    }
}
impl AsyncRead for TlsStream {
    /// Reads decrypted bytes into `buf`, completing the TLS handshake first if it
    /// is still pending.
    ///
    /// When the handshake finishes during this call, the read is attempted on the
    /// freshly established stream before it is stored for subsequent calls.
    /// A handshake failure is returned as the read error.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let mut established = match this.state {
            // Fast path: the session already exists.
            State::Streaming(ref mut tls) => return Pin::new(tls).poll_read(cx, buf),
            // `ready!` returns `Poll::Pending` early while the handshake is in flight.
            State::Handshaking(ref mut handshake) => match ready!(Pin::new(handshake).poll(cx)) {
                Ok(tls) => tls,
                Err(e) => return Poll::Ready(Err(e)),
            },
        };
        let outcome = Pin::new(&mut established).poll_read(cx, buf);
        this.state = State::Streaming(established);
        outcome
    }
}
impl AsyncWrite for TlsStream {
    /// Encrypts and writes `buf`, completing the TLS handshake first if it is
    /// still pending.
    ///
    /// When the handshake finishes during this call, the write is attempted on the
    /// freshly established stream before it is stored for subsequent calls.
    /// A handshake failure is returned as the write error.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        let mut established = match this.state {
            State::Streaming(ref mut tls) => return Pin::new(tls).poll_write(cx, buf),
            State::Handshaking(ref mut handshake) => match ready!(Pin::new(handshake).poll(cx)) {
                Ok(tls) => tls,
                Err(e) => return Poll::Ready(Err(e)),
            },
        };
        let outcome = Pin::new(&mut established).poll_write(cx, buf);
        this.state = State::Streaming(established);
        outcome
    }

    /// Flushes the TLS session. While the handshake is still pending this reports
    /// success immediately without polling the handshake.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match &mut self.get_mut().state {
            State::Streaming(tls) => Pin::new(tls).poll_flush(cx),
            State::Handshaking(_) => Poll::Ready(Ok(())),
        }
    }

    /// Shuts down the TLS session. While the handshake is still pending this
    /// reports success immediately without polling the handshake.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match &mut self.get_mut().state {
            State::Streaming(tls) => Pin::new(tls).poll_shutdown(cx),
            State::Handshaking(_) => Poll::Ready(Ok(())),
        }
    }
}

/// Hyper `Accept` source that wraps every connection from `incoming` in a
/// [`TlsStream`] using the shared rustls `config`.
pub struct TlsAcceptor {
    config: Arc<ServerConfig>,
    incoming: AddrIncoming,
}

impl TlsAcceptor {
    /// Builds an acceptor serving TLS with `config` on the connections yielded
    /// by `incoming`.
    pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
        Self { incoming, config }
    }
}
impl Accept for TlsAcceptor {
    type Conn = TlsStream;
    type Error = std::io::Error;

    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        let pin = self.get_mut();
        match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
            Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => Poll::Ready(None),
        }
    }
}

pub(super) fn error(err: String)->std::io::Error {
    std::io::Error::new(std::io::ErrorKind::Other, err)
}

/// Loads `certificates/<name>/fullchain.pem` and `certificates/<name>/privkey.pem`
/// and registers them with `resolver` under the SNI name `hostname`.
///
/// # Panics
///
/// Panics with the offending path and the underlying error if either file cannot
/// be read or parsed, or if the key is rejected by `RsaSigningKey::new`. It also
/// panics with `Invalid certificate for <hostname>: <error>` if the resolver
/// rejects the certificate.
pub(super) fn add_certificate_to_resolver(
    name: &str,
    hostname: &str,
    resolver: &mut ResolvesServerCertUsingSni,
) {
    let cert_path = format!("certificates/{}/fullchain.pem", name);
    let key_path = format!("certificates/{}/privkey.pem", name);

    // Load in the same order as before: chain, then key, then signing key.
    let chain = load_certs(&cert_path)
        .unwrap_or_else(|e| panic!("failed to load certificate chain {}: {}", cert_path, e));
    let key = load_private_key(&key_path)
        .unwrap_or_else(|e| panic!("failed to load private key {}: {}", key_path, e));
    let signing_key = RsaSigningKey::new(&key)
        .unwrap_or_else(|e| panic!("private key {} is not a usable RSA key: {:?}", key_path, e));

    resolver
        .add(
            hostname,
            rustls::sign::CertifiedKey::new(chain, Arc::new(signing_key)),
        )
        // Built lazily so no String is allocated on success (clippy::expect_fun_call);
        // the text matches what the previous `expect` printed.
        .unwrap_or_else(|e| panic!("Invalid certificate for {}: {:?}", hostname, e));
}

pub(super) fn load_certs(filename: &str)->std::io::Result<Vec<rustls::Certificate>>{
    let certs = rustls_pemfile::certs(&mut std::io::BufReader::new(
        std::fs::File::open(filename).map_err(|e| error(format!("failed to open {}: {}", filename, e)))?
    ))
        .map_err(|_| error("failed to load certificate".into()))?;
    Ok(
        certs.into_iter().map(rustls::Certificate).collect()
    )
}

pub(super) fn load_private_key(filename: &str)->std::io::Result<rustls::PrivateKey> {
    let keys=rustls_pemfile::rsa_private_keys(&mut std::io::BufReader::new(
        std::fs::File::open(filename).map_err(|e| error(format!("failed to open {}: {}", filename, e)))?
    )).map_err(|_| error("failed to load private key".into()))?;
    if keys.len() != 1 {
        return Err(error("expected a single private key".into()));
    }
    Ok(rustls::PrivateKey(keys[0].clone()))
}