openssl_async/
handshake.rs1use std::error::Error as StdError;
2use std::fmt;
3use std::pin::Pin;
4
5use openssl::error::ErrorStack;
6use openssl::ssl;
7use openssl::x509::X509VerifyResult;
8
9use futures::io::{AsyncRead, AsyncWrite};
10use futures::prelude::*;
11use futures::task::{Context, Poll};
12
13use async_stdio::*;
14
15use crate::SslStream;
16
17pub struct MidHandshakeSslStream<S>(HandshakeInner<S>);
22
23impl<S: fmt::Debug> fmt::Debug for MidHandshakeSslStream<S> {
24 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
25 f.debug_tuple("MidHandshakeSslStream")
26 .field(&self.0)
27 .finish()
28 }
29}
30
31impl<S> MidHandshakeSslStream<S> {
32 pub(crate) fn new(inner: ssl::MidHandshakeSslStream<AsStdIo<S>>, ctrl: WakerCtrl) -> Self {
33 MidHandshakeSslStream(HandshakeInner {
34 inner: Some((inner, ctrl)),
35 })
36 }
37}
38
39struct HandshakeInner<S> {
40 inner: Option<(ssl::MidHandshakeSslStream<AsStdIo<S>>, WakerCtrl)>,
41}
42
43impl<S: fmt::Debug> fmt::Debug for HandshakeInner<S> {
44 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45 f.debug_struct("SslStream")
46 .field("inner", &self.inner)
47 .finish()
48 }
49}
50
51pub enum HandshakeError<S> {
53 SetupFailure(ErrorStack),
55 Failure(MidHandshakeSslStream<S>),
57}
58
59impl<S: fmt::Debug> fmt::Debug for HandshakeError<S> {
60 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61 match self {
62 HandshakeError::SetupFailure(ref e) => f.debug_tuple("SetupFailure").field(e).finish(),
63 HandshakeError::Failure(ref s) => f.debug_tuple("Failure").field(s).finish(),
64 }
65 }
66}
67
68impl<S: fmt::Debug> StdError for HandshakeError<S> {
69 fn description(&self) -> &str {
70 match *self {
71 HandshakeError::SetupFailure(_) => "stream setup failed",
72 HandshakeError::Failure(_) => "the handshake failed",
73 }
74 }
75
76 fn source(&self) -> Option<&(dyn StdError + 'static)> {
77 match *self {
78 HandshakeError::SetupFailure(ref e) => Some(e),
79 HandshakeError::Failure(ref s) => {
80 s.0.inner.as_ref().map(|s| s.0.error() as &dyn StdError)
81 }
82 }
83 }
84}
85
86impl<S: fmt::Debug> fmt::Display for HandshakeError<S> {
87 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
88 f.write_str(StdError::description(self))?;
89 match *self {
90 HandshakeError::SetupFailure(ref e) => write!(f, ": {}", e)?,
91 HandshakeError::Failure(ref s) => {
92 if let Some(s) = s.0.inner.as_ref().map(|s| &s.0) {
93 write!(f, ": {}", s.error())?;
94 let verify = s.ssl().verify_result();
95 if verify != X509VerifyResult::OK {
96 write!(f, ": {}", verify)?;
97 }
98 }
99 }
100 }
101 Ok(())
102 }
103}
104
105impl<S> From<ErrorStack> for HandshakeError<S> {
106 fn from(other: ErrorStack) -> Self {
107 HandshakeError::SetupFailure(other)
108 }
109}
110
111impl<S> HandshakeError<S> {
112 pub(crate) fn from_ssl(err: ssl::HandshakeError<AsStdIo<S>>, ctrl: WakerCtrl) -> Option<Self> {
113 Some(match err {
114 ssl::HandshakeError::SetupFailure(e) => HandshakeError::SetupFailure(e),
115 ssl::HandshakeError::Failure(inner) => {
116 HandshakeError::Failure(MidHandshakeSslStream::new(inner, ctrl))
117 }
118 _ => return None,
119 })
120 }
121}
122
123impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshakeSslStream<S> {
124 type Output = Result<SslStream<S>, HandshakeError<S>>;
125
126 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127 let this = Pin::get_mut(self);
128 let mut handshake = &mut this.0;
129 let (inner, ctrl) = handshake
130 .inner
131 .take()
132 .expect("handshake polled after completion");
133
134 ctrl.register(cx.waker());
135
136 match inner.handshake() {
137 Ok(inner) => Poll::Ready(Ok(SslStream { inner, ctrl })),
138 Err(ssl::HandshakeError::WouldBlock(inner)) => {
139 handshake.inner = Some((inner, ctrl));
140 Poll::Pending
141 }
142 Err(err) => Poll::Ready(Err(HandshakeError::from_ssl(err, ctrl).unwrap())),
143 }
144 }
145}