1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use std::fmt;
use std::pin::Pin;
use openssl::error::ErrorStack;
use openssl::ssl;
use futures::io::{AsyncRead, AsyncWrite};
use futures::prelude::*;
use futures::task::{Context, Poll};
use async_stdio::*;
use crate::SslStream;
/// A TLS handshake that has not finished yet.
///
/// This type implements `Future`: polling it drives the wrapped OpenSSL
/// handshake and resolves to an `SslStream<S>` on success or a
/// `HandshakeError<S>` on failure.
pub struct MidHandshakeSslStream<S>(HandshakeInner<S>);
impl<S: fmt::Debug> fmt::Debug for MidHandshakeSslStream<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_tuple("MidHandshakeSslStream")
.field(&self.0)
.finish()
}
}
impl<S> MidHandshakeSslStream<S> {
pub(crate) fn new(
inner: ssl::MidHandshakeSslStream<AsStdIo<S>>,
ctrl: WakerCtrlHandle,
) -> Self {
MidHandshakeSslStream(HandshakeInner {
inner: Some(inner),
ctrl,
})
}
}
/// Private state behind `MidHandshakeSslStream`.
struct HandshakeInner<S> {
// The OpenSSL handshake in progress. `poll` takes it out of the `Option`
// for each attempt and only puts it back when the handshake would block,
// so it is `None` once the future has resolved.
inner: Option<ssl::MidHandshakeSslStream<AsStdIo<S>>>,
// Handle on which `poll` registers the current task's waker; a clone is
// handed to the resulting `SslStream` (or to a returned error stream).
ctrl: WakerCtrlHandle,
}
impl<S: fmt::Debug> fmt::Debug for HandshakeInner<S> {
    /// Formats the handshake state as a struct named `HandshakeInner`.
    ///
    /// Only the `inner` field is printed; the `ctrl` handle is deliberately
    /// left out of the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Label the output with this type's own name (it previously claimed
        // to be an `SslStream`, which made debug dumps misleading).
        f.debug_struct("HandshakeInner")
            .field("inner", &self.inner)
            .finish()
    }
}
/// Error produced when a TLS handshake does not complete successfully.
///
/// Mirrors the `SetupFailure` and `Failure` variants of
/// `openssl::ssl::HandshakeError`; see `from_ssl` for the mapping.
pub enum HandshakeError<S> {
/// Carries the OpenSSL error stack from `ssl::HandshakeError::SetupFailure`.
/// NOTE(review): presumably raised before any handshake I/O took place;
/// confirm against the openssl crate docs.
SetupFailure(ErrorStack),
/// The handshake failed; the mid-handshake stream is handed back to the
/// caller, re-wrapped in this crate's `MidHandshakeSslStream`.
Failure(MidHandshakeSslStream<S>),
}
impl<S: fmt::Debug> fmt::Debug for HandshakeError<S> {
    /// Formats the error as a one-field tuple variant (`SetupFailure(..)` or
    /// `Failure(..)`), delegating to the payload's `Debug` impl.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            HandshakeError::SetupFailure(ref stack) => {
                let mut tuple = f.debug_tuple("SetupFailure");
                tuple.field(stack);
                tuple.finish()
            }
            HandshakeError::Failure(ref stream) => {
                let mut tuple = f.debug_tuple("Failure");
                tuple.field(stream);
                tuple.finish()
            }
        }
    }
}
impl<S> HandshakeError<S> {
    /// Converts an OpenSSL handshake error into this crate's error type.
    ///
    /// `SetupFailure` is passed through unchanged and `Failure` has its
    /// stream re-wrapped together with `ctrl`. Every other variant yields
    /// `None`, leaving it to the caller to handle (the `Future` impl deals
    /// with `WouldBlock` itself before calling this).
    pub(crate) fn from_ssl(
        err: ssl::HandshakeError<AsStdIo<S>>,
        ctrl: WakerCtrlHandle,
    ) -> Option<Self> {
        match err {
            ssl::HandshakeError::SetupFailure(stack) => {
                Some(HandshakeError::SetupFailure(stack))
            }
            ssl::HandshakeError::Failure(stream) => {
                let wrapped = MidHandshakeSslStream::new(stream, ctrl);
                Some(HandshakeError::Failure(wrapped))
            }
            _ => None,
        }
    }
}
impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshakeSslStream<S> {
    type Output = Result<SslStream<S>, HandshakeError<S>>;

    /// Makes one attempt at advancing the handshake.
    ///
    /// Returns `Poll::Ready(Ok(stream))` once the handshake completes,
    /// `Poll::Pending` when OpenSSL reports `WouldBlock`, and
    /// `Poll::Ready(Err(..))` for any other handshake error.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let handshake = &mut this.0;

        // `handshake()` consumes the OpenSSL stream, so move it out of the
        // slot; it is put back only on the `WouldBlock` path below.
        let inner = handshake
            .inner
            .take()
            .expect("handshake polled after completion");

        // Register the current task's waker before attempting the handshake.
        handshake.ctrl.register(cx.waker());

        match inner.handshake() {
            Ok(inner) => Poll::Ready(Ok(SslStream {
                inner,
                ctrl: handshake.ctrl.clone(),
            })),
            Err(ssl::HandshakeError::WouldBlock(inner)) => {
                // Not finished yet: keep the stream for the next poll.
                handshake.inner = Some(inner);
                Poll::Pending
            }
            Err(err) => Poll::Ready(Err(
                HandshakeError::from_ssl(err, handshake.ctrl.clone())
                    .expect("WouldBlock is handled by the preceding match arm"),
            )),
        }
    }
}