openssl_async/
connect.rs

1use std::mem;
2use std::pin::Pin;
3
4use openssl::ssl::{self, ConnectConfiguration, SslConnector};
5
6use futures::io::{AsyncRead, AsyncWrite};
7use futures::prelude::*;
8use futures::task::{Context, Poll};
9
10use async_stdio::*;
11
12use crate::{HandshakeError, MidHandshakeSslStream, SslStream};
13
/// Extension trait for [SslConnector] to allow connections to be initiated
/// asynchronously.
pub trait SslConnectorExt {
    /// Asynchronously initiate the SSL connection.
    ///
    /// * `domain` - the domain name that is eventually handed to
    ///   `ConnectConfiguration::connect` when the returned future is polled.
    /// * `stream` - the asynchronous transport the handshake runs over.
    ///
    /// Returns a [ConnectAsync]. Note that [ConnectAsync] only implements
    /// `Future` when `S` is also `Unpin`.
    fn connect_async<S: AsyncRead + AsyncWrite>(&self, domain: &str, stream: S) -> ConnectAsync<S>;
}
20
/// Extension trait for [ConnectConfiguration] to allow connections to be
/// initiated asynchronously.
pub trait ConnectConfigurationExt {
    /// Asynchronously initiate the SSL connection, consuming the
    /// configuration.
    ///
    /// * `domain` - the domain name that is eventually handed to
    ///   `ConnectConfiguration::connect` when the returned future is polled.
    /// * `stream` - the asynchronous transport the handshake runs over.
    ///
    /// Returns a [ConnectAsync]. Note that [ConnectAsync] only implements
    /// `Future` when `S` is also `Unpin`.
    fn connect_async<S: AsyncRead + AsyncWrite>(self, domain: &str, stream: S) -> ConnectAsync<S>;
}
27
28impl ConnectConfigurationExt for ConnectConfiguration {
29    fn connect_async<S: AsyncRead + AsyncWrite>(self, domain: &str, stream: S) -> ConnectAsync<S> {
30        ConnectAsync(ConnectInner::Init(self, domain.into(), stream))
31    }
32}
33
34impl SslConnectorExt for SslConnector {
35    fn connect_async<S: AsyncRead + AsyncWrite>(&self, domain: &str, stream: S) -> ConnectAsync<S> {
36        match self.configure() {
37            Ok(s) => s.connect_async(domain, stream),
38            Err(e) => ConnectAsync(ConnectInner::Error(HandshakeError::SetupFailure(e))),
39        }
40    }
41}
42
43/// The future returned from [SslConnectorExt::connect_async]
44///
45/// Resolves to a [SslStream]
46pub struct ConnectAsync<S>(ConnectInner<S>);
47
/// Internal state machine driving a [ConnectAsync] future.
enum ConnectInner<S> {
    /// Not polled yet: the configuration, the owned domain name and the raw
    /// stream, waiting for the first `poll` to start the handshake.
    Init(ConnectConfiguration, String, S),
    /// The handshake reported `WouldBlock`; it is stored here so that the
    /// next `poll` can continue it.
    Handshake(MidHandshakeSslStream<S>),
    /// An error captured before any handshake was attempted; returned by the
    /// next `poll`.
    Error(HandshakeError<S>),
    /// Placeholder swapped in while `poll` works on the previous state. It is
    /// also what remains once a result has been returned; polling in this
    /// state panics.
    Done,
}
54
55impl<S: AsyncRead + AsyncWrite + Unpin> Future for ConnectAsync<S> {
56    type Output = Result<SslStream<S>, HandshakeError<S>>;
57
58    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
59        let this = Pin::get_mut(self);
60
61        match mem::replace(&mut this.0, ConnectInner::Done) {
62            ConnectInner::Init(config, domain, stream) => {
63                let (stream, ctrl) = AsStdIo::new(stream, cx.waker().into());
64                match config.connect(&domain, stream) {
65                    Ok(inner) => Poll::Ready(Ok(SslStream { inner, ctrl })),
66                    Err(ssl::HandshakeError::WouldBlock(inner)) => {
67                        this.0 = ConnectInner::Handshake(MidHandshakeSslStream::new(inner, ctrl));
68                        Poll::Pending
69                    }
70                    Err(e) => Poll::Ready(Err(HandshakeError::from_ssl(e, ctrl).unwrap())),
71                }
72            }
73            ConnectInner::Handshake(mut handshake) => match Pin::new(&mut handshake).poll(cx) {
74                Poll::Ready(result) => Poll::Ready(result),
75                Poll::Pending => {
76                    this.0 = ConnectInner::Handshake(handshake);
77                    Poll::Pending
78                }
79            },
80            ConnectInner::Error(e) => Poll::Ready(Err(e)),
81            ConnectInner::Done => panic!("accept polled after completion"),
82        }
83    }
84}