actix_tls/accept/
native_tls.rs

1//! `native-tls` based TLS connection acceptor service.
2//!
3//! See [`Acceptor`] for main service factory docs.
4
5use std::{
6    convert::Infallible,
7    io::{self, IoSlice},
8    pin::Pin,
9    task::{Context, Poll},
10    time::Duration,
11};
12
13use actix_rt::{
14    net::{ActixStream, Ready},
15    time::timeout,
16};
17use actix_service::{Service, ServiceFactory};
18use actix_utils::{
19    counter::Counter,
20    future::{ready, Ready as FutReady},
21};
22use futures_core::future::LocalBoxFuture;
23use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
24use tokio_native_tls::{native_tls::Error, TlsAcceptor};
25
26use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};
27
28pub mod reexports {
29    //! Re-exports from `native-tls` that are useful for acceptors.
30
31    pub use tokio_native_tls::{native_tls::Error, TlsAcceptor};
32}
33
34/// Wraps a `native-tls` based async TLS stream in order to implement [`ActixStream`].
35pub struct TlsStream<IO>(tokio_native_tls::TlsStream<IO>);
36
37impl_more::impl_from!(<IO> in tokio_native_tls::TlsStream<IO> => TlsStream<IO>);
38impl_more::impl_deref_and_mut!(<IO> in TlsStream<IO> => tokio_native_tls::TlsStream<IO>);
39
40impl<IO: ActixStream> AsyncRead for TlsStream<IO> {
41    fn poll_read(
42        self: Pin<&mut Self>,
43        cx: &mut Context<'_>,
44        buf: &mut ReadBuf<'_>,
45    ) -> Poll<io::Result<()>> {
46        Pin::new(&mut **self.get_mut()).poll_read(cx, buf)
47    }
48}
49
50impl<IO: ActixStream> AsyncWrite for TlsStream<IO> {
51    fn poll_write(
52        self: Pin<&mut Self>,
53        cx: &mut Context<'_>,
54        buf: &[u8],
55    ) -> Poll<io::Result<usize>> {
56        Pin::new(&mut **self.get_mut()).poll_write(cx, buf)
57    }
58
59    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
60        Pin::new(&mut **self.get_mut()).poll_flush(cx)
61    }
62
63    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
64        Pin::new(&mut **self.get_mut()).poll_shutdown(cx)
65    }
66
67    fn poll_write_vectored(
68        self: Pin<&mut Self>,
69        cx: &mut Context<'_>,
70        bufs: &[IoSlice<'_>],
71    ) -> Poll<io::Result<usize>> {
72        Pin::new(&mut **self.get_mut()).poll_write_vectored(cx, bufs)
73    }
74
75    fn is_write_vectored(&self) -> bool {
76        (**self).is_write_vectored()
77    }
78}
79
80impl<IO: ActixStream> ActixStream for TlsStream<IO> {
81    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
82        IO::poll_read_ready((**self).get_ref().get_ref().get_ref(), cx)
83    }
84
85    fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
86        IO::poll_write_ready((**self).get_ref().get_ref().get_ref(), cx)
87    }
88}
89
90/// Accept TLS connections via the `native-tls` crate.
91pub struct Acceptor {
92    acceptor: TlsAcceptor,
93    handshake_timeout: Duration,
94}
95
96impl Acceptor {
97    /// Constructs `native-tls` based acceptor service factory.
98    pub fn new(acceptor: TlsAcceptor) -> Self {
99        Acceptor {
100            acceptor,
101            handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
102        }
103    }
104
105    /// Limit the amount of time that the acceptor will wait for a TLS handshake to complete.
106    ///
107    /// Default timeout is 3 seconds.
108    pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self {
109        self.handshake_timeout = handshake_timeout;
110        self
111    }
112}
113
114impl Clone for Acceptor {
115    #[inline]
116    fn clone(&self) -> Self {
117        Self {
118            acceptor: self.acceptor.clone(),
119            handshake_timeout: self.handshake_timeout,
120        }
121    }
122}
123
124impl<IO: ActixStream + 'static> ServiceFactory<IO> for Acceptor {
125    type Response = TlsStream<IO>;
126    type Error = TlsError<Error, Infallible>;
127    type Config = ();
128    type Service = AcceptorService;
129    type InitError = ();
130    type Future = FutReady<Result<Self::Service, Self::InitError>>;
131
132    fn new_service(&self, _: ()) -> Self::Future {
133        let res = MAX_CONN_COUNTER.with(|conns| {
134            Ok(AcceptorService {
135                acceptor: self.acceptor.clone(),
136                conns: conns.clone(),
137                handshake_timeout: self.handshake_timeout,
138            })
139        });
140
141        ready(res)
142    }
143}
144
145/// Native-TLS based acceptor service.
146pub struct AcceptorService {
147    acceptor: TlsAcceptor,
148    conns: Counter,
149    handshake_timeout: Duration,
150}
151
152impl<IO: ActixStream + 'static> Service<IO> for AcceptorService {
153    type Response = TlsStream<IO>;
154    type Error = TlsError<Error, Infallible>;
155    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
156
157    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
158        if self.conns.available(cx) {
159            Poll::Ready(Ok(()))
160        } else {
161            Poll::Pending
162        }
163    }
164
165    fn call(&self, io: IO) -> Self::Future {
166        let guard = self.conns.get();
167        let acceptor = self.acceptor.clone();
168
169        let dur = self.handshake_timeout;
170
171        Box::pin(async move {
172            match timeout(dur, acceptor.accept(io)).await {
173                Ok(Ok(io)) => {
174                    drop(guard);
175                    Ok(TlsStream(io))
176                }
177                Ok(Err(err)) => Err(TlsError::Tls(err)),
178                Err(_timeout) => Err(TlsError::Timeout),
179            }
180        })
181    }
182}