via 2.0.0-gm.48

An async multi-threaded web framework for people who appreciate simplicity.
Documentation
use native_tls::{Identity, Protocol};
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_native_tls::{TlsAcceptor, TlsStream};

use crate::server::tls::NegotiateAlpn;

use super::{Acceptor, Alpn};

pub struct NativeTlsAcceptor(Arc<TlsAcceptor>);

pub struct NativeTlsStream {
    alpn: Alpn,
    stream: TlsStream<TcpStream>,
}

impl NativeTlsAcceptor {
    pub fn new(identity: Identity) -> Self {
        Self(Arc::new(TlsAcceptor::from(
            native_tls::TlsAcceptor::builder(identity)
                .min_protocol_version(Some(Protocol::Tlsv12))
                .build()
                .expect("tls config is invalid or missing"),
        )))
    }
}

impl Acceptor for NativeTlsAcceptor {
    type Stream = NativeTlsStream;
    type Error = native_tls::Error;

    fn accept(
        &self,
        io: TcpStream,
    ) -> impl Future<Output = Result<Self::Stream, Self::Error>> + Send + 'static {
        let acceptor = Arc::clone(&self.0);

        async move {
            let stream = acceptor.accept(io).await?;
            let inner = stream.get_ref();
            let alpn = match inner.negotiated_alpn()? {
                Some(value) if value == b"h2" => Alpn::HTTP_2,
                _ => Alpn::HTTP_11,
            };

            Ok(NativeTlsStream { alpn, stream })
        }
    }
}

impl AsyncRead for NativeTlsStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf,
    ) -> Poll<io::Result<()>> {
        Pin::new(&mut self.stream).poll_read(cx, buf)
    }
}

impl AsyncWrite for NativeTlsStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.stream).poll_write(cx, buf)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        Pin::new(&mut self.stream).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        Pin::new(&mut self.stream).poll_shutdown(cx)
    }

    fn is_write_vectored(&self) -> bool {
        self.stream.is_write_vectored()
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        bufs: &[io::IoSlice],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
    }
}

impl NegotiateAlpn for NativeTlsStream {
    fn preferred_alpn(&self) -> &Alpn {
        &self.alpn
    }
}