1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
use std::error::Error as StdError;
use std::fmt;
use std::pin::Pin;
use openssl::error::ErrorStack;
use openssl::ssl;
use openssl::x509::X509VerifyResult;
use futures::io::{AsyncRead, AsyncWrite};
use futures::prelude::*;
use futures::task::{Context, Poll};
use async_stdio::*;
use crate::SslStream;
/// A TLS handshake that has started but not yet finished.
///
/// Polling it as a `Future` drives the handshake forward. It resolves to an
/// `SslStream` on success or a `HandshakeError` on failure.
pub struct MidHandshakeSslStream<S>(HandshakeInner<S>);
impl<S: fmt::Debug> fmt::Debug for MidHandshakeSslStream<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_tuple("MidHandshakeSslStream")
.field(&self.0)
.finish()
}
}
impl<S> MidHandshakeSslStream<S> {
pub(crate) fn new(inner: ssl::MidHandshakeSslStream<AsStdIo<S>>, ctrl: WakerCtrl) -> Self {
MidHandshakeSslStream(HandshakeInner {
inner: Some((inner, ctrl)),
})
}
}
/// Internal state of a `MidHandshakeSslStream`.
///
/// Holds the in-progress OpenSSL handshake paired with the `WakerCtrl` that
/// the task's waker is registered with on each poll.
struct HandshakeInner<S> {
    // `Some` while the handshake is pending. `poll` `take`s it on every call and
    // only puts it back when the handshake would block, so it is `None` once
    // the future has resolved.
    inner: Option<(ssl::MidHandshakeSslStream<AsStdIo<S>>, WakerCtrl)>,
}
impl<S: fmt::Debug> fmt::Debug for HandshakeInner<S> {
    /// Formats as `HandshakeInner { inner: .. }`.
    ///
    /// The struct name reported here is this type's own name, so debug output
    /// for a pending handshake is not mistaken for an established `SslStream`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("HandshakeInner")
            .field("inner", &self.inner)
            .finish()
    }
}
/// Error produced when a TLS handshake cannot be completed.
pub enum HandshakeError<S> {
    /// Setting up the SSL stream failed; carries the OpenSSL error stack.
    SetupFailure(ErrorStack),
    /// The handshake itself failed; carries the stream in its failed state.
    /// Its OpenSSL error is exposed through `source()` and `Display`.
    Failure(MidHandshakeSslStream<S>),
}
impl<S: fmt::Debug> fmt::Debug for HandshakeError<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
HandshakeError::SetupFailure(ref e) => f.debug_tuple("SetupFailure").field(e).finish(),
HandshakeError::Failure(ref s) => f.debug_tuple("Failure").field(s).finish(),
}
}
}
impl<S: fmt::Debug> StdError for HandshakeError<S> {
fn description(&self) -> &str {
match *self {
HandshakeError::SetupFailure(_) => "stream setup failed",
HandshakeError::Failure(_) => "the handshake failed",
}
}
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match *self {
HandshakeError::SetupFailure(ref e) => Some(e),
HandshakeError::Failure(ref s) => {
s.0.inner.as_ref().map(|s| s.0.error() as &dyn StdError)
}
}
}
}
impl<S: fmt::Debug> fmt::Display for HandshakeError<S> {
    /// Renders a human-readable message.
    ///
    /// - `SetupFailure`: `stream setup failed: <error stack>`.
    /// - `Failure`: `the handshake failed`. When the stream still holds its
    ///   handshake state, this is followed by `: <ssl error>`, and then by
    ///   `: <verify result>` if certificate verification did not return OK.
    ///
    /// The prefixes are written directly instead of via the deprecated
    /// `Error::description`, which would emit a deprecation warning.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            HandshakeError::SetupFailure(e) => write!(f, "stream setup failed: {}", e),
            HandshakeError::Failure(s) => {
                f.write_str("the handshake failed")?;
                if let Some((stream, _)) = s.0.inner.as_ref() {
                    write!(f, ": {}", stream.error())?;
                    let verify = stream.ssl().verify_result();
                    if verify != X509VerifyResult::OK {
                        write!(f, ": {}", verify)?;
                    }
                }
                Ok(())
            }
        }
    }
}
impl<S> From<ErrorStack> for HandshakeError<S> {
fn from(other: ErrorStack) -> Self {
HandshakeError::SetupFailure(other)
}
}
impl<S> HandshakeError<S> {
    /// Converts an OpenSSL handshake error into this crate's error type.
    ///
    /// `SetupFailure` and `Failure` are mapped to their counterparts. A failed
    /// stream is re-wrapped together with `ctrl`. Any other variant (such as
    /// `WouldBlock`) is not a terminal error and yields `None`.
    pub(crate) fn from_ssl(err: ssl::HandshakeError<AsStdIo<S>>, ctrl: WakerCtrl) -> Option<Self> {
        match err {
            ssl::HandshakeError::SetupFailure(stack) => Some(Self::SetupFailure(stack)),
            ssl::HandshakeError::Failure(mid) => {
                Some(Self::Failure(MidHandshakeSslStream::new(mid, ctrl)))
            }
            _ => None,
        }
    }
}
impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshakeSslStream<S> {
    type Output = Result<SslStream<S>, HandshakeError<S>>;

    /// Drives the TLS handshake one step.
    ///
    /// The current task's waker is registered with the stream's `WakerCtrl`
    /// before the handshake is resumed. There are three outcomes:
    /// - OpenSSL reports it would block: the in-progress state is stored back
    ///   and `Poll::Pending` is returned.
    /// - The handshake completes: the future resolves to the established
    ///   `SslStream`.
    /// - The handshake fails: the future resolves to a `HandshakeError`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let handshake = &mut self.get_mut().0;
        let (inner, ctrl) = handshake
            .inner
            .take()
            .expect("handshake polled after completion");
        // NOTE(review): presumably the `AsStdIo` adapter wakes this waker when
        // the underlying I/O becomes ready; confirm against `async_stdio`.
        ctrl.register(cx.waker());
        match inner.handshake() {
            Ok(inner) => Poll::Ready(Ok(SslStream { inner, ctrl })),
            Err(ssl::HandshakeError::WouldBlock(inner)) => {
                // Not finished yet: keep the state for the next poll.
                handshake.inner = Some((inner, ctrl));
                Poll::Pending
            }
            Err(err) => {
                let err = HandshakeError::from_ssl(err, ctrl)
                    .expect("`WouldBlock` is handled above, so `from_ssl` maps every other error");
                Poll::Ready(Err(err))
            }
        }
    }
}