// hyper_serve/tls_openssl/future.rs
//! Future types used by the OpenSSL TLS acceptor.

use std::io::{Error, ErrorKind};
use std::time::Duration;
use std::{
    fmt,
    future::Future,
    io,
    pin::Pin,
    task::{Context, Poll},
};

use openssl::ssl::Ssl;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::{timeout, Timeout};
use tokio_openssl::SslStream;

use super::OpenSSLConfig;

20pin_project! {
21    /// Future type for [`OpenSSLAcceptor`](crate::tls_openssl::OpenSSLAcceptor).
22    pub struct OpenSSLAcceptorFuture<F, I, S> {
23        #[pin]
24        inner: AcceptFuture<F, I, S>,
25        config: Option<OpenSSLConfig>,
26    }
27}
28
29impl<F, I, S> OpenSSLAcceptorFuture<F, I, S> {
30    pub(crate) fn new(future: F, config: OpenSSLConfig, handshake_timeout: Duration) -> Self {
31        let inner = AcceptFuture::InnerAccepting {
32            future,
33            handshake_timeout,
34        };
35        let config = Some(config);
36
37        Self { inner, config }
38    }
39}
40
41impl<F, I, S> fmt::Debug for OpenSSLAcceptorFuture<F, I, S> {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        f.debug_struct("OpenSSLAcceptorFuture").finish()
44    }
45}
46
47pin_project! {
48    struct TlsAccept<I> {
49        #[pin]
50        tls_stream: Option<SslStream<I>>,
51    }
52}
53
54impl<I> Future for TlsAccept<I>
55where
56    I: AsyncRead + AsyncWrite + Unpin,
57{
58    type Output = io::Result<SslStream<I>>;
59
60    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
61        let mut this = self.project();
62
63        match this
64            .tls_stream
65            .as_mut()
66            .as_pin_mut()
67            .map(|inner| inner.poll_accept(cx))
68            .expect("tlsaccept polled after ready")
69        {
70            Poll::Ready(Ok(())) => {
71                let tls_stream = this.tls_stream.take().expect("tls stream vanished?");
72
73                Poll::Ready(Ok(tls_stream))
74            }
75            Poll::Ready(Err(e)) => Poll::Ready(Err(Error::new(ErrorKind::Other, e))),
76            Poll::Pending => Poll::Pending,
77        }
78    }
79}
80
81pin_project! {
82    #[project = AcceptFutureProj]
83    enum AcceptFuture<F, I, S> {
84        // We are waiting on the inner (lower) future to complete accept()
85        // so that we can begin installing TLS into the channel.
86        InnerAccepting {
87            #[pin]
88            future: F,
89            handshake_timeout: Duration,
90        },
91        // We are waiting for TLS to install into the channel so that we can
92        // proceed to return the SslStream.
93        TlsAccepting {
94            #[pin]
95            future: Timeout< TlsAccept<I> >,
96            service: Option<S>,
97        }
98    }
99}
100
101impl<F, I, S> Future for OpenSSLAcceptorFuture<F, I, S>
102where
103    F: Future<Output = io::Result<(I, S)>>,
104    I: AsyncRead + AsyncWrite + Unpin,
105{
106    type Output = io::Result<(SslStream<I>, S)>;
107
108    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
109        let mut this = self.project();
110
111        // The inner future here is what is doing the lower level accept, such as
112        // our tcp socket.
113        //
114        // So we poll on that first, when it's ready we then swap our the inner future to
115        // one waiting for our ssl layer to accept/install.
116        //
117        // Then once that's ready we can then wrap and provide the SslStream back out.
118
119        // This loop exists to allow the Poll::Ready from InnerAccept on complete
120        // to re-poll immediately. Otherwise all other paths are immediate returns.
121        loop {
122            match this.inner.as_mut().project() {
123                AcceptFutureProj::InnerAccepting {
124                    future,
125                    handshake_timeout,
126                } => match future.poll(cx) {
127                    Poll::Ready(Ok((stream, service))) => {
128                        let server_config = this.config.take().expect(
129                            "config is not set. this is a bug in hyper-serve, please report",
130                        );
131
132                        // Change to poll::ready(err)
133                        let ssl = Ssl::new(server_config.get_inner().context()).unwrap();
134
135                        let tls_stream = SslStream::new(ssl, stream).unwrap();
136                        let future = TlsAccept {
137                            tls_stream: Some(tls_stream),
138                        };
139
140                        let service = Some(service);
141                        let handshake_timeout = *handshake_timeout;
142
143                        this.inner.set(AcceptFuture::TlsAccepting {
144                            future: timeout(handshake_timeout, future),
145                            service,
146                        });
147                        // the loop is now triggered to immediately poll on
148                        // ssl stream accept.
149                    }
150                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
151                    Poll::Pending => return Poll::Pending,
152                },
153
154                AcceptFutureProj::TlsAccepting { future, service } => match future.poll(cx) {
155                    Poll::Ready(Ok(Ok(stream))) => {
156                        let service = service.take().expect("future polled after ready");
157
158                        return Poll::Ready(Ok((stream, service)));
159                    }
160                    Poll::Ready(Ok(Err(e))) => return Poll::Ready(Err(e)),
161                    Poll::Ready(Err(timeout)) => {
162                        return Poll::Ready(Err(Error::new(ErrorKind::TimedOut, timeout)))
163                    }
164                    Poll::Pending => return Poll::Pending,
165                },
166            }
167        }
168    }
169}