tls-async 0.3.0-alpha.7

TLS support for AsyncRead/AsyncWrite using native-tls
Documentation
use crate::errors::Error;
use crate::TlsStream;

use std::pin::Pin;
use std::task::Context;

use futures::Future;
use futures::compat::Compat;
use futures::io::{AsyncRead, AsyncWrite};
use futures::Poll;
use log::debug;
use native_tls::{HandshakeError, MidHandshakeTlsStream};
pub use native_tls::{TlsConnector as NativeTlsConnector, TlsStream as NativeTlsStream};

enum Handshake<S> {
    Error(Error),
    Midhandshake(MidHandshakeTlsStream<Compat<S>>),
    Completed(NativeTlsStream<Compat<S>>),
}

impl<S> Handshake<S> {
    pub fn was_pending(&self) -> bool {
        if let Handshake::Midhandshake(_) = self {
            true
        } else {
            false
        }
    }
}

type NativeHandshake<S> = Result<NativeTlsStream<Compat<S>>, HandshakeError<Compat<S>>>;

impl<S> From<NativeHandshake<S>> for Handshake<S> {
    fn from(v: NativeHandshake<S>) -> Self {
        match v {
            Ok(native_stream) => Handshake::Completed(native_stream),
            Err(HandshakeError::Failure(e)) => Handshake::Error(Error::Handshake(e)),
            Err(HandshakeError::WouldBlock(midhandshake_stream)) => Handshake::Midhandshake(midhandshake_stream),
        }
    }
}

pub struct PendingTlsStream<S> {
    inner: Handshake<S>,
}

impl<S> PendingTlsStream<S> {
    pub fn new(inner: NativeHandshake<S>) -> Self {
        PendingTlsStream {
            inner: Handshake::from(inner)
        }
    }
    fn inner<'a>(self: Pin<&'a mut Self>) -> &'a mut Handshake<S> {
        unsafe {
            &mut Pin::get_unchecked_mut(self).inner
        }
    }
}

impl<S: AsyncRead + AsyncWrite + std::fmt::Debug + Unpin> Future for PendingTlsStream<S> {
    type Output = Result<TlsStream<S>, Error>;

    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let handshake = std::mem::replace(self.as_mut().inner(), Handshake::Error(Error::RepeatedHandshake));
            match handshake {
                Handshake::Error(e) => return Poll::Ready(Err(e)),
                Handshake::Midhandshake(midhandshake_stream) => {
                    debug!("Connection was interrupted mid handshake, attempting handshake");
                    let res = Handshake::from(midhandshake_stream.handshake());
                    let was_pending = res.was_pending();
                    *self.as_mut().inner() = res;
                    if was_pending {
                        return Poll::Pending;
                    }
                }
                Handshake::Completed(native_stream) => {
                    debug!("Connection was completed");
                    return Poll::Ready(Ok(TlsStream { inner: native_stream }))
                }
            }
        }
    }
}