hyper_server/tls_rustls/
future.rs

//! Module containing futures specific to the `rustls` TLS acceptor for the server.
//!
//! This module primarily provides [`RustlsAcceptorFuture`], which is responsible for performing the
//! TLS handshake using the `rustls` library.
5
6use crate::tls_rustls::RustlsConfig;
7use pin_project_lite::pin_project;
8use std::io::{Error, ErrorKind};
9use std::time::Duration;
10use std::{
11    fmt,
12    future::Future,
13    io,
14    pin::Pin,
15    task::{Context, Poll},
16};
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio::time::{timeout, Timeout};
19use tokio_rustls::{server::TlsStream, Accept, TlsAcceptor};
20
21pin_project! {
22    /// A future representing the asynchronous TLS handshake using the `rustls` library.
23    ///
24    /// Once completed, it yields a `TlsStream` which is a wrapper around the actual underlying stream, with
25    /// encryption and decryption operations applied to it.
26    pub struct RustlsAcceptorFuture<F, I, S> {
27        #[pin]
28        inner: AcceptFuture<F, I, S>,
29        config: Option<RustlsConfig>,
30    }
31}
32
33impl<F, I, S> RustlsAcceptorFuture<F, I, S> {
34    /// Constructs a new `RustlsAcceptorFuture`.
35    ///
36    /// * `future`: The future that resolves to the original non-encrypted stream.
37    /// * `config`: The rustls configuration to use for the handshake.
38    /// * `handshake_timeout`: The maximum duration to wait for the handshake to complete.
39    pub(crate) fn new(future: F, config: RustlsConfig, handshake_timeout: Duration) -> Self {
40        let inner = AcceptFuture::Inner {
41            future,
42            handshake_timeout,
43        };
44        let config = Some(config);
45
46        Self { inner, config }
47    }
48}
49
50impl<F, I, S> fmt::Debug for RustlsAcceptorFuture<F, I, S> {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        f.debug_struct("RustlsAcceptorFuture").finish()
53    }
54}
55
56pin_project! {
57    /// Internal states of the handshake process.
58    #[project = AcceptFutureProj]
59    enum AcceptFuture<F, I, S> {
60        /// Initial state where we have a future that resolves to the original non-encrypted stream.
61        Inner {
62            #[pin]
63            future: F,
64            handshake_timeout: Duration,
65        },
66        /// State after receiving the stream where the handshake is performed asynchronously.
67        Accept {
68            #[pin]
69            future: Timeout<Accept<I>>,
70            service: Option<S>,
71        },
72    }
73}
74
75impl<F, I, S> Future for RustlsAcceptorFuture<F, I, S>
76where
77    F: Future<Output = io::Result<(I, S)>>,
78    I: AsyncRead + AsyncWrite + Unpin,
79{
80    type Output = io::Result<(TlsStream<I>, S)>;
81
82    /// Advances the handshake state machine.
83    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
84        let mut this = self.project();
85
86        loop {
87            match this.inner.as_mut().project() {
88                AcceptFutureProj::Inner {
89                    future,
90                    handshake_timeout,
91                } => {
92                    // Poll the future to get the original stream.
93                    match future.poll(cx) {
94                        Poll::Ready(Ok((stream, service))) => {
95                            let server_config = this.config
96                                .take()
97                                .expect("config is not set. this is a bug in hyper-server, please report")
98                                .get_inner();
99
100                            let acceptor = TlsAcceptor::from(server_config);
101                            let future = acceptor.accept(stream);
102
103                            let service = Some(service);
104                            let handshake_timeout = *handshake_timeout;
105
106                            this.inner.set(AcceptFuture::Accept {
107                                future: timeout(handshake_timeout, future),
108                                service,
109                            });
110                        }
111                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
112                        Poll::Pending => return Poll::Pending,
113                    }
114                }
115                AcceptFutureProj::Accept { future, service } => match future.poll(cx) {
116                    Poll::Ready(Ok(Ok(stream))) => {
117                        let service = service.take().expect("future polled after ready");
118
119                        return Poll::Ready(Ok((stream, service)));
120                    }
121                    Poll::Ready(Ok(Err(e))) => return Poll::Ready(Err(e)),
122                    Poll::Ready(Err(timeout)) => {
123                        return Poll::Ready(Err(Error::new(ErrorKind::TimedOut, timeout)))
124                    }
125                    Poll::Pending => return Poll::Pending,
126                },
127            }
128        }
129    }
130}