actix_tls/accept/
rustls_0_22.rs

1//! `rustls` v0.22 based TLS connection acceptor service.
2//!
3//! See [`Acceptor`] for main service factory docs.
4
5use std::{
6    convert::Infallible,
7    future::Future,
8    io::{self, IoSlice},
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12    time::Duration,
13};
14
15use actix_rt::{
16    net::{ActixStream, Ready},
17    time::{sleep, Sleep},
18};
19use actix_service::{Service, ServiceFactory};
20use actix_utils::{
21    counter::{Counter, CounterGuard},
22    future::{ready, Ready as FutReady},
23};
24use pin_project_lite::pin_project;
25use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
26use tokio_rustls::{Accept, TlsAcceptor};
27use tokio_rustls_025 as tokio_rustls;
28
29use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};
30
pub mod reexports {
    //! Re-exports from `rustls` that are useful for acceptors.

    // Re-exported because `Acceptor::new` takes exactly this `ServerConfig` type (the one from
    // the `rustls` that `tokio_rustls_025` is built on).
    pub use tokio_rustls_025::rustls::ServerConfig;
}
36
37/// Wraps a `rustls` based async TLS stream in order to implement [`ActixStream`].
38pub struct TlsStream<IO>(tokio_rustls::server::TlsStream<IO>);
39
40impl_more::impl_from!(<IO> in tokio_rustls::server::TlsStream<IO> => TlsStream<IO>);
41impl_more::impl_deref_and_mut!(<IO> in TlsStream<IO> => tokio_rustls::server::TlsStream<IO>);
42
43impl<IO: ActixStream> AsyncRead for TlsStream<IO> {
44    fn poll_read(
45        self: Pin<&mut Self>,
46        cx: &mut Context<'_>,
47        buf: &mut ReadBuf<'_>,
48    ) -> Poll<io::Result<()>> {
49        Pin::new(&mut **self.get_mut()).poll_read(cx, buf)
50    }
51}
52
53impl<IO: ActixStream> AsyncWrite for TlsStream<IO> {
54    fn poll_write(
55        self: Pin<&mut Self>,
56        cx: &mut Context<'_>,
57        buf: &[u8],
58    ) -> Poll<io::Result<usize>> {
59        Pin::new(&mut **self.get_mut()).poll_write(cx, buf)
60    }
61
62    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
63        Pin::new(&mut **self.get_mut()).poll_flush(cx)
64    }
65
66    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
67        Pin::new(&mut **self.get_mut()).poll_shutdown(cx)
68    }
69
70    fn poll_write_vectored(
71        self: Pin<&mut Self>,
72        cx: &mut Context<'_>,
73        bufs: &[IoSlice<'_>],
74    ) -> Poll<io::Result<usize>> {
75        Pin::new(&mut **self.get_mut()).poll_write_vectored(cx, bufs)
76    }
77
78    fn is_write_vectored(&self) -> bool {
79        (**self).is_write_vectored()
80    }
81}
82
83impl<IO: ActixStream> ActixStream for TlsStream<IO> {
84    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
85        IO::poll_read_ready((**self).get_ref().0, cx)
86    }
87
88    fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
89        IO::poll_write_ready((**self).get_ref().0, cx)
90    }
91}
92
93/// Accept TLS connections via the `rustls` crate.
94pub struct Acceptor {
95    config: Arc<reexports::ServerConfig>,
96    handshake_timeout: Duration,
97}
98
99impl Acceptor {
100    /// Constructs `rustls` based acceptor service factory.
101    pub fn new(config: reexports::ServerConfig) -> Self {
102        Acceptor {
103            config: Arc::new(config),
104            handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
105        }
106    }
107
108    /// Limit the amount of time that the acceptor will wait for a TLS handshake to complete.
109    ///
110    /// Default timeout is 3 seconds.
111    pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self {
112        self.handshake_timeout = handshake_timeout;
113        self
114    }
115}
116
117impl Clone for Acceptor {
118    fn clone(&self) -> Self {
119        Self {
120            config: self.config.clone(),
121            handshake_timeout: self.handshake_timeout,
122        }
123    }
124}
125
126impl<IO: ActixStream> ServiceFactory<IO> for Acceptor {
127    type Response = TlsStream<IO>;
128    type Error = TlsError<io::Error, Infallible>;
129    type Config = ();
130    type Service = AcceptorService;
131    type InitError = ();
132    type Future = FutReady<Result<Self::Service, Self::InitError>>;
133
134    fn new_service(&self, _: ()) -> Self::Future {
135        let res = MAX_CONN_COUNTER.with(|conns| {
136            Ok(AcceptorService {
137                acceptor: self.config.clone().into(),
138                conns: conns.clone(),
139                handshake_timeout: self.handshake_timeout,
140            })
141        });
142
143        ready(res)
144    }
145}
146
147/// Rustls based acceptor service.
148pub struct AcceptorService {
149    acceptor: TlsAcceptor,
150    conns: Counter,
151    handshake_timeout: Duration,
152}
153
154impl<IO: ActixStream> Service<IO> for AcceptorService {
155    type Response = TlsStream<IO>;
156    type Error = TlsError<io::Error, Infallible>;
157    type Future = AcceptFut<IO>;
158
159    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
160        if self.conns.available(cx) {
161            Poll::Ready(Ok(()))
162        } else {
163            Poll::Pending
164        }
165    }
166
167    fn call(&self, req: IO) -> Self::Future {
168        AcceptFut {
169            fut: self.acceptor.accept(req),
170            timeout: sleep(self.handshake_timeout),
171            _guard: self.conns.get(),
172        }
173    }
174}
175
pin_project! {
    /// Accept future for Rustls service.
    ///
    /// Resolves to the established [`TlsStream`] when the handshake succeeds, to
    /// `TlsError::Tls` when it fails, or to `TlsError::Timeout` if the handshake is still
    /// pending when the timer fires.
    #[doc(hidden)]
    pub struct AcceptFut<IO: ActixStream> {
        // In-progress `tokio-rustls` handshake; polled via `Pin::new`, so it is not
        // structurally pinned.
        fut: Accept<IO>,
        // Handshake deadline; structurally pinned because it is polled as `Pin<&mut Sleep>`.
        #[pin]
        timeout: Sleep,
        // Connection-counter guard from `AcceptorService::call`, held for as long as this
        // future lives (presumably releasing the slot on drop).
        _guard: CounterGuard,
    }
}
186
187impl<IO: ActixStream> Future for AcceptFut<IO> {
188    type Output = Result<TlsStream<IO>, TlsError<io::Error, Infallible>>;
189
190    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
191        let mut this = self.project();
192        match Pin::new(&mut this.fut).poll(cx) {
193            Poll::Ready(Ok(stream)) => Poll::Ready(Ok(TlsStream(stream))),
194            Poll::Ready(Err(err)) => Poll::Ready(Err(TlsError::Tls(err))),
195            Poll::Pending => this.timeout.poll(cx).map(|_| Err(TlsError::Timeout)),
196        }
197    }
198}