openssl-async 0.3.0-alpha.5

Wrappers for the `openssl` crate that allow it to be used in async applications.
Documentation
use std::error::Error as StdError;
use std::fmt;
use std::pin::Pin;

use openssl::error::ErrorStack;
use openssl::ssl;
use openssl::x509::X509VerifyResult;

use futures::io::{AsyncRead, AsyncWrite};
use futures::prelude::*;
use futures::task::{Context, Poll};

use async_stdio::*;

use crate::SslStream;

/// An SSL stream mid-handshake.
///
/// This wraps the [`ssl::MidHandshakeSslStream`] type to make it a
/// [`Future`] that resolves to an [`SslStream`] once the handshake
/// completes, or to a [`HandshakeError`] if it fails.
///
/// The same type is carried inside [`HandshakeError::Failure`] so the
/// underlying OpenSSL error can be inspected.
///
/// # Panics
///
/// Polling this future again after it has returned `Poll::Ready` panics.
pub struct MidHandshakeSslStream<S>(HandshakeInner<S>);

impl<S: fmt::Debug> fmt::Debug for MidHandshakeSslStream<S> {
    /// Formats as `MidHandshakeSslStream(<handshake state>)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_tuple("MidHandshakeSslStream");
        builder.field(&self.0);
        builder.finish()
    }
}

impl<S> MidHandshakeSslStream<S> {
    pub(crate) fn new(inner: ssl::MidHandshakeSslStream<AsStdIo<S>>, ctrl: WakerCtrl) -> Self {
        MidHandshakeSslStream(HandshakeInner {
            inner: Some((inner, ctrl)),
        })
    }
}

/// Internal state of a [`MidHandshakeSslStream`].
struct HandshakeInner<S> {
    /// The OpenSSL mid-handshake stream paired with the `WakerCtrl` that the
    /// polling task's waker is registered with.
    ///
    /// `Some` while the handshake can still be driven. `Future::poll` takes
    /// the pair out on every call and only puts it back when the handshake
    /// would block, so it is `None` once the future has resolved.
    inner: Option<(ssl::MidHandshakeSslStream<AsStdIo<S>>, WakerCtrl)>,
}

impl<S: fmt::Debug> fmt::Debug for HandshakeInner<S> {
    /// Formats as `HandshakeInner { inner: .. }`.
    ///
    /// The type's own name is used so that an in-progress or failed handshake
    /// is not mistaken for a completed `SslStream` in debug output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("HandshakeInner")
            .field("inner", &self.inner)
            .finish()
    }
}

/// Errors that may arise from a handshake.
///
/// Unlike [`ssl::HandshakeError`], this has no `WouldBlock` variant: a
/// handshake that would block is reported as `Poll::Pending` by the
/// [`MidHandshakeSslStream`] future instead of as an error.
pub enum HandshakeError<S> {
    /// The handshake could not be started.
    SetupFailure(ErrorStack),
    /// An error was encountered mid-handshake.
    ///
    /// The OpenSSL error held by the contained stream is exposed through
    /// `Error::source` and included in the `Display` output.
    Failure(MidHandshakeSslStream<S>),
}

impl<S: fmt::Debug> fmt::Debug for HandshakeError<S> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            HandshakeError::SetupFailure(ref e) => f.debug_tuple("SetupFailure").field(e).finish(),
            HandshakeError::Failure(ref s) => f.debug_tuple("Failure").field(s).finish(),
        }
    }
}

impl<S: fmt::Debug> StdError for HandshakeError<S> {
    fn description(&self) -> &str {
        match *self {
            HandshakeError::SetupFailure(_) => "stream setup failed",
            HandshakeError::Failure(_) => "the handshake failed",
        }
    }

    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match *self {
            HandshakeError::SetupFailure(ref e) => Some(e),
            HandshakeError::Failure(ref s) => {
                s.0.inner.as_ref().map(|s| s.0.error() as &dyn StdError)
            }
        }
    }
}

impl<S: fmt::Debug> fmt::Display for HandshakeError<S> {
    /// Formats as `<summary>[: <cause>[: <certificate verify result>]]`.
    ///
    /// The summary text is written directly instead of through the deprecated
    /// `Error::description`, so this impl compiles without deprecation
    /// warnings. The output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            HandshakeError::SetupFailure(ref e) => write!(f, "stream setup failed: {}", e),
            HandshakeError::Failure(ref s) => {
                f.write_str("the handshake failed")?;
                // The inner stream is absent once the wrapped future has
                // taken it; in that case only the summary is printed.
                if let Some((stream, _)) = s.0.inner.as_ref() {
                    write!(f, ": {}", stream.error())?;
                    // Append the certificate verification result only when
                    // verification did not succeed.
                    let verify = stream.ssl().verify_result();
                    if verify != X509VerifyResult::OK {
                        write!(f, ": {}", verify)?;
                    }
                }
                Ok(())
            }
        }
    }
}

impl<S> From<ErrorStack> for HandshakeError<S> {
    fn from(other: ErrorStack) -> Self {
        HandshakeError::SetupFailure(other)
    }
}

impl<S> HandshakeError<S> {
    /// Converts an OpenSSL handshake error into this crate's error type.
    ///
    /// `ctrl` is attached to the wrapped stream in the `Failure` case.
    /// Returns `None` for any other variant (the `WouldBlock` case), which
    /// is not an error for an async handshake; the `Future` impl for
    /// `MidHandshakeSslStream` handles it before calling this.
    pub(crate) fn from_ssl(err: ssl::HandshakeError<AsStdIo<S>>, ctrl: WakerCtrl) -> Option<Self> {
        match err {
            ssl::HandshakeError::SetupFailure(stack) => Some(HandshakeError::SetupFailure(stack)),
            ssl::HandshakeError::Failure(mid) => {
                let stream = MidHandshakeSslStream::new(mid, ctrl);
                Some(HandshakeError::Failure(stream))
            }
            _ => None,
        }
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshakeSslStream<S> {
    type Output = Result<SslStream<S>, HandshakeError<S>>;

    /// Drives the OpenSSL handshake one step.
    ///
    /// Resolves to an `SslStream` on success, or to a `HandshakeError` on
    /// failure. Returns `Poll::Pending` when OpenSSL reports `WouldBlock`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = Pin::get_mut(self);
        // Only the pointee is assigned below, so the binding itself need not
        // be `mut`.
        let handshake = &mut this.0;
        let (mid, ctrl) = handshake
            .inner
            .take()
            .expect("handshake polled after completion");

        // Register the current task's waker before driving the handshake.
        // NOTE(review): presumably the `AsStdIo` adapter uses this to wake the
        // task once I/O that would have blocked becomes ready; confirm in
        // async_stdio.
        ctrl.register(cx.waker());

        match mid.handshake() {
            Ok(inner) => Poll::Ready(Ok(SslStream { inner, ctrl })),
            Err(ssl::HandshakeError::WouldBlock(mid)) => {
                // Put the state back so the next poll resumes the handshake.
                handshake.inner = Some((mid, ctrl));
                Poll::Pending
            }
            Err(err) => {
                let err = HandshakeError::from_ssl(err, ctrl)
                    .expect("from_ssl only returns None for WouldBlock, which is handled above");
                Poll::Ready(Err(err))
            }
        }
    }
}