// openssl_async/handshake.rs

1use std::error::Error as StdError;
2use std::fmt;
3use std::pin::Pin;
4
5use openssl::error::ErrorStack;
6use openssl::ssl;
7use openssl::x509::X509VerifyResult;
8
9use futures::io::{AsyncRead, AsyncWrite};
10use futures::prelude::*;
11use futures::task::{Context, Poll};
12
13use async_stdio::*;
14
15use crate::SslStream;
16
17/// An SSL stream mid-handshake
18///
19/// This wraps the [ssl::MidHandshakeSslStream] type to make it a
20/// [Future] that resolves to an [SslStream].
21pub struct MidHandshakeSslStream<S>(HandshakeInner<S>);
22
23impl<S: fmt::Debug> fmt::Debug for MidHandshakeSslStream<S> {
24    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
25        f.debug_tuple("MidHandshakeSslStream")
26            .field(&self.0)
27            .finish()
28    }
29}
30
31impl<S> MidHandshakeSslStream<S> {
32    pub(crate) fn new(inner: ssl::MidHandshakeSslStream<AsStdIo<S>>, ctrl: WakerCtrl) -> Self {
33        MidHandshakeSslStream(HandshakeInner {
34            inner: Some((inner, ctrl)),
35        })
36    }
37}
38
39struct HandshakeInner<S> {
40    inner: Option<(ssl::MidHandshakeSslStream<AsStdIo<S>>, WakerCtrl)>,
41}
42
43impl<S: fmt::Debug> fmt::Debug for HandshakeInner<S> {
44    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45        f.debug_struct("SslStream")
46            .field("inner", &self.inner)
47            .finish()
48    }
49}
50
51/// Errors that may arise from a handshake
52pub enum HandshakeError<S> {
53    /// The handshake could not be started
54    SetupFailure(ErrorStack),
55    /// An error was encountered mid-handshake
56    Failure(MidHandshakeSslStream<S>),
57}
58
59impl<S: fmt::Debug> fmt::Debug for HandshakeError<S> {
60    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61        match self {
62            HandshakeError::SetupFailure(ref e) => f.debug_tuple("SetupFailure").field(e).finish(),
63            HandshakeError::Failure(ref s) => f.debug_tuple("Failure").field(s).finish(),
64        }
65    }
66}
67
68impl<S: fmt::Debug> StdError for HandshakeError<S> {
69    fn description(&self) -> &str {
70        match *self {
71            HandshakeError::SetupFailure(_) => "stream setup failed",
72            HandshakeError::Failure(_) => "the handshake failed",
73        }
74    }
75
76    fn source(&self) -> Option<&(dyn StdError + 'static)> {
77        match *self {
78            HandshakeError::SetupFailure(ref e) => Some(e),
79            HandshakeError::Failure(ref s) => {
80                s.0.inner.as_ref().map(|s| s.0.error() as &dyn StdError)
81            }
82        }
83    }
84}
85
86impl<S: fmt::Debug> fmt::Display for HandshakeError<S> {
87    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
88        f.write_str(StdError::description(self))?;
89        match *self {
90            HandshakeError::SetupFailure(ref e) => write!(f, ": {}", e)?,
91            HandshakeError::Failure(ref s) => {
92                if let Some(s) = s.0.inner.as_ref().map(|s| &s.0) {
93                    write!(f, ": {}", s.error())?;
94                    let verify = s.ssl().verify_result();
95                    if verify != X509VerifyResult::OK {
96                        write!(f, ": {}", verify)?;
97                    }
98                }
99            }
100        }
101        Ok(())
102    }
103}
104
105impl<S> From<ErrorStack> for HandshakeError<S> {
106    fn from(other: ErrorStack) -> Self {
107        HandshakeError::SetupFailure(other)
108    }
109}
110
111impl<S> HandshakeError<S> {
112    pub(crate) fn from_ssl(err: ssl::HandshakeError<AsStdIo<S>>, ctrl: WakerCtrl) -> Option<Self> {
113        Some(match err {
114            ssl::HandshakeError::SetupFailure(e) => HandshakeError::SetupFailure(e),
115            ssl::HandshakeError::Failure(inner) => {
116                HandshakeError::Failure(MidHandshakeSslStream::new(inner, ctrl))
117            }
118            _ => return None,
119        })
120    }
121}
122
123impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshakeSslStream<S> {
124    type Output = Result<SslStream<S>, HandshakeError<S>>;
125
126    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127        let this = Pin::get_mut(self);
128        let mut handshake = &mut this.0;
129        let (inner, ctrl) = handshake
130            .inner
131            .take()
132            .expect("handshake polled after completion");
133
134        ctrl.register(cx.waker());
135
136        match inner.handshake() {
137            Ok(inner) => Poll::Ready(Ok(SslStream { inner, ctrl })),
138            Err(ssl::HandshakeError::WouldBlock(inner)) => {
139                handshake.inner = Some((inner, ctrl));
140                Poll::Pending
141            }
142            Err(err) => Poll::Ready(Err(HandshakeError::from_ssl(err, ctrl).unwrap())),
143        }
144    }
145}