1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use std::fmt;
use std::pin::Pin;

use openssl::error::ErrorStack;
use openssl::ssl;

use futures::io::{AsyncRead, AsyncWrite};
use futures::prelude::*;
use futures::task::{Context, Poll};

use async_stdio::*;

use crate::SslStream;

/// An SSL stream mid-handshake
///
/// This wraps the [ssl::MidHandshakeSslStream] type to make it a
/// [Future] that resolves to an [SslStream].
///
/// Each poll resumes the OpenSSL handshake:
/// - it resolves to `Ok(SslStream)` once the handshake completes,
/// - it stays pending while OpenSSL reports `WouldBlock`,
/// - it resolves to a [HandshakeError] otherwise.
///
/// Polling it again after it has resolved panics.
pub struct MidHandshakeSslStream<S>(HandshakeInner<S>);

impl<S: fmt::Debug> fmt::Debug for MidHandshakeSslStream<S> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_tuple("MidHandshakeSslStream")
            .field(&self.0)
            .finish()
    }
}

impl<S> MidHandshakeSslStream<S> {
    pub(crate) fn new(
        inner: ssl::MidHandshakeSslStream<AsStdIo<S>>,
        ctrl: WakerCtrlHandle,
    ) -> Self {
        MidHandshakeSslStream(HandshakeInner {
            inner: Some(inner),
            ctrl,
        })
    }
}

/// Internal state of a [`MidHandshakeSslStream`].
struct HandshakeInner<S> {
    /// The in-progress OpenSSL stream. It is taken out at the start of every
    /// poll and put back only when OpenSSL reports `WouldBlock`. After the
    /// future has resolved (success or error) it stays `None`.
    inner: Option<ssl::MidHandshakeSslStream<AsStdIo<S>>>,
    /// Handle on which each poll registers the current task's waker. It is
    /// cloned into the resulting `SslStream` or `HandshakeError`.
    ctrl: WakerCtrlHandle,
}

impl<S: fmt::Debug> fmt::Debug for HandshakeInner<S> {
    /// Renders the in-progress OpenSSL stream under the struct's own name.
    ///
    /// The `ctrl` handle is deliberately left out of the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Use the real type name so debug output does not masquerade as an
        // `SslStream`.
        f.debug_struct("HandshakeInner")
            .field("inner", &self.inner)
            .finish()
    }
}

/// Errors that may arise from a handshake
pub enum HandshakeError<S> {
    /// The handshake could not be started.
    ///
    /// Carries the OpenSSL [ErrorStack] describing the setup failure.
    SetupFailure(ErrorStack),
    /// An error was encountered mid-handshake.
    ///
    /// Carries the wrapped mid-handshake stream in which the failure occurred.
    Failure(MidHandshakeSslStream<S>),
}

impl<S: fmt::Debug> fmt::Debug for HandshakeError<S> {
    /// Renders as `SetupFailure(<error stack>)` or `Failure(<stream>)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (variant, payload): (&str, &dyn fmt::Debug) = match self {
            HandshakeError::SetupFailure(stack) => ("SetupFailure", stack as &dyn fmt::Debug),
            HandshakeError::Failure(stream) => ("Failure", stream as &dyn fmt::Debug),
        };
        f.debug_tuple(variant).field(payload).finish()
    }
}

impl<S> HandshakeError<S> {
    /// Converts an OpenSSL handshake error into this crate's error type.
    ///
    /// `SetupFailure` is passed through unchanged. `Failure` has its stream
    /// re-wrapped as a [`MidHandshakeSslStream`] that carries `ctrl`. Every
    /// other variant (e.g. `WouldBlock`) yields `None`.
    pub(crate) fn from_ssl(
        err: ssl::HandshakeError<AsStdIo<S>>,
        ctrl: WakerCtrlHandle,
    ) -> Option<Self> {
        match err {
            ssl::HandshakeError::SetupFailure(stack) => Some(HandshakeError::SetupFailure(stack)),
            ssl::HandshakeError::Failure(mid) => {
                let stream = MidHandshakeSslStream::new(mid, ctrl);
                Some(HandshakeError::Failure(stream))
            }
            _ => None,
        }
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshakeSslStream<S> {
    type Output = Result<SslStream<S>, HandshakeError<S>>;

    /// Resumes the OpenSSL handshake.
    ///
    /// Returns:
    /// - `Ready(Ok(SslStream))` when the handshake completes.
    /// - `Pending` when OpenSSL reports `WouldBlock`. The in-progress stream
    ///   is stored so the next poll can continue from where it left off.
    /// - `Ready(Err(HandshakeError))` on a setup or mid-handshake failure.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let handshake = &mut self.get_mut().0;
        let inner = handshake
            .inner
            .take()
            .expect("handshake polled after completion");

        // Register this task's waker before resuming the handshake.
        // NOTE(review): presumably this lets the `AsStdIo` adapter wake the
        // task when its I/O would block; confirm against `async_stdio`.
        handshake.ctrl.register(cx.waker());

        match inner.handshake() {
            Ok(inner) => Poll::Ready(Ok(SslStream {
                inner,
                ctrl: handshake.ctrl.clone(),
            })),
            Err(ssl::HandshakeError::WouldBlock(inner)) => {
                // Park the in-progress stream so the next poll can resume it.
                handshake.inner = Some(inner);
                Poll::Pending
            }
            Err(err) => {
                // `from_ssl` only returns `None` for `WouldBlock`, which the
                // arm above has already consumed.
                let err = HandshakeError::from_ssl(err, handshake.ctrl.clone())
                    .expect("`WouldBlock` is handled by the previous match arm");
                Poll::Ready(Err(err))
            }
        }
    }
}