// axum_server/tls_openssl/future.rs

//! Future types.
2
3use super::OpenSSLConfig;
4use pin_project_lite::pin_project;
5use std::io::{Error, ErrorKind};
6use std::time::Duration;
7use std::{
8    fmt,
9    future::Future,
10    io,
11    pin::Pin,
12    task::{Context, Poll},
13};
14use tokio::io::{AsyncRead, AsyncWrite};
15use tokio::time::{timeout, Timeout};
16
17use openssl::ssl::Ssl;
18use tokio_openssl::SslStream;
19
20pin_project! {
21    /// Future type for [`OpenSSLAcceptor`](crate::tls_openssl::OpenSSLAcceptor).
22    pub struct OpenSSLAcceptorFuture<F, I, S> {
23        #[pin]
24        inner: AcceptFuture<F, I, S>,
25        config: Option<OpenSSLConfig>,
26    }
27}
28
29impl<F, I, S> OpenSSLAcceptorFuture<F, I, S> {
30    pub(crate) fn new(future: F, config: OpenSSLConfig, handshake_timeout: Duration) -> Self {
31        let inner = AcceptFuture::InnerAccepting {
32            future,
33            handshake_timeout,
34        };
35        let config = Some(config);
36
37        Self { inner, config }
38    }
39}
40
41impl<F, I, S> fmt::Debug for OpenSSLAcceptorFuture<F, I, S> {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        f.debug_struct("OpenSSLAcceptorFuture").finish()
44    }
45}
46
47pin_project! {
48    struct TlsAccept<I> {
49        #[pin]
50        tls_stream: Option<SslStream<I>>,
51    }
52}
53
54impl<I> Future for TlsAccept<I>
55where
56    I: AsyncRead + AsyncWrite + Unpin,
57{
58    type Output = io::Result<SslStream<I>>;
59
60    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
61        let mut this = self.project();
62
63        match this
64            .tls_stream
65            .as_mut()
66            .as_pin_mut()
67            .map(|inner| inner.poll_accept(cx))
68            .expect("tlsaccept polled after ready")
69        {
70            Poll::Ready(Ok(())) => {
71                let tls_stream = this.tls_stream.take().expect("tls stream vanished?");
72
73                Poll::Ready(Ok(tls_stream))
74            }
75            Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
76
77            Poll::Pending => Poll::Pending,
78        }
79    }
80}
81
82pin_project! {
83    #[project = AcceptFutureProj]
84    enum AcceptFuture<F, I, S> {
85        // We are waiting on the inner (lower) future to complete accept()
86        // so that we can begin installing TLS into the channel.
87        InnerAccepting {
88            #[pin]
89            future: F,
90            handshake_timeout: Duration,
91        },
92        // We are waiting for TLS to install into the channel so that we can
93        // proceed to return the SslStream.
94        TlsAccepting {
95            #[pin]
96            future: Timeout< TlsAccept<I> >,
97            service: Option<S>,
98        }
99    }
100}
101
102impl<F, I, S> Future for OpenSSLAcceptorFuture<F, I, S>
103where
104    F: Future<Output = io::Result<(I, S)>>,
105    I: AsyncRead + AsyncWrite + Unpin,
106{
107    type Output = io::Result<(SslStream<I>, S)>;
108
109    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
110        let mut this = self.project();
111
112        // The inner future here is what is doing the lower level accept, such as
113        // our tcp socket.
114        //
115        // So we poll on that first, when it's ready we then swap our the inner future to
116        // one waiting for our ssl layer to accept/install.
117        //
118        // Then once that's ready we can then wrap and provide the SslStream back out.
119
120        // This loop exists to allow the Poll::Ready from InnerAccept on complete
121        // to re-poll immediately. Otherwise all other paths are immediate returns.
122        loop {
123            match this.inner.as_mut().project() {
124                AcceptFutureProj::InnerAccepting {
125                    future,
126                    handshake_timeout,
127                } => match future.poll(cx) {
128                    Poll::Ready(Ok((stream, service))) => {
129                        let server_config = this.config.take().expect(
130                            "config is not set. this is a bug in axum-server, please report",
131                        );
132
133                        // Change to poll::ready(err)
134                        let ssl = Ssl::new(server_config.get_inner().context()).unwrap();
135
136                        let tls_stream = SslStream::new(ssl, stream).unwrap();
137                        let future = TlsAccept {
138                            tls_stream: Some(tls_stream),
139                        };
140
141                        let service = Some(service);
142                        let handshake_timeout = *handshake_timeout;
143
144                        this.inner.set(AcceptFuture::TlsAccepting {
145                            future: timeout(handshake_timeout, future),
146                            service,
147                        });
148                        // the loop is now triggered to immediately poll on
149                        // ssl stream accept.
150                    }
151                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
152                    Poll::Pending => return Poll::Pending,
153                },
154
155                AcceptFutureProj::TlsAccepting { future, service } => match future.poll(cx) {
156                    Poll::Ready(Ok(Ok(stream))) => {
157                        let service = service.take().expect("future polled after ready");
158
159                        return Poll::Ready(Ok((stream, service)));
160                    }
161                    Poll::Ready(Ok(Err(e))) => return Poll::Ready(Err(e)),
162                    Poll::Ready(Err(timeout)) => {
163                        return Poll::Ready(Err(Error::new(ErrorKind::TimedOut, timeout)));
164                    }
165                    Poll::Pending => return Poll::Pending,
166                },
167            }
168        }
169    }
170}