// actix_tls/accept/native_tls.rs — `native-tls` based TLS connection acceptor service.

use std::{
    convert::Infallible,
    io::{self, IoSlice},
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};

use actix_rt::{
    net::{ActixStream, Ready},
    time::timeout,
};
use actix_service::{Service, ServiceFactory};
use actix_utils::{
    counter::Counter,
    future::{ready, Ready as FutReady},
};
use futures_core::future::LocalBoxFuture;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_native_tls::{native_tls::Error, TlsAcceptor};

use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};
27
28pub mod reexports {
29 pub use tokio_native_tls::{native_tls::Error, TlsAcceptor};
32}
33
34pub struct TlsStream<IO>(tokio_native_tls::TlsStream<IO>);
36
37impl_more::impl_from!(<IO> in tokio_native_tls::TlsStream<IO> => TlsStream<IO>);
38impl_more::impl_deref_and_mut!(<IO> in TlsStream<IO> => tokio_native_tls::TlsStream<IO>);
39
40impl<IO: ActixStream> AsyncRead for TlsStream<IO> {
41 fn poll_read(
42 self: Pin<&mut Self>,
43 cx: &mut Context<'_>,
44 buf: &mut ReadBuf<'_>,
45 ) -> Poll<io::Result<()>> {
46 Pin::new(&mut **self.get_mut()).poll_read(cx, buf)
47 }
48}
49
50impl<IO: ActixStream> AsyncWrite for TlsStream<IO> {
51 fn poll_write(
52 self: Pin<&mut Self>,
53 cx: &mut Context<'_>,
54 buf: &[u8],
55 ) -> Poll<io::Result<usize>> {
56 Pin::new(&mut **self.get_mut()).poll_write(cx, buf)
57 }
58
59 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
60 Pin::new(&mut **self.get_mut()).poll_flush(cx)
61 }
62
63 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
64 Pin::new(&mut **self.get_mut()).poll_shutdown(cx)
65 }
66
67 fn poll_write_vectored(
68 self: Pin<&mut Self>,
69 cx: &mut Context<'_>,
70 bufs: &[IoSlice<'_>],
71 ) -> Poll<io::Result<usize>> {
72 Pin::new(&mut **self.get_mut()).poll_write_vectored(cx, bufs)
73 }
74
75 fn is_write_vectored(&self) -> bool {
76 (**self).is_write_vectored()
77 }
78}
79
80impl<IO: ActixStream> ActixStream for TlsStream<IO> {
81 fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
82 IO::poll_read_ready((**self).get_ref().get_ref().get_ref(), cx)
83 }
84
85 fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
86 IO::poll_write_ready((**self).get_ref().get_ref().get_ref(), cx)
87 }
88}
89
90pub struct Acceptor {
92 acceptor: TlsAcceptor,
93 handshake_timeout: Duration,
94}
95
96impl Acceptor {
97 pub fn new(acceptor: TlsAcceptor) -> Self {
99 Acceptor {
100 acceptor,
101 handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
102 }
103 }
104
105 pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self {
109 self.handshake_timeout = handshake_timeout;
110 self
111 }
112}
113
114impl Clone for Acceptor {
115 #[inline]
116 fn clone(&self) -> Self {
117 Self {
118 acceptor: self.acceptor.clone(),
119 handshake_timeout: self.handshake_timeout,
120 }
121 }
122}
123
124impl<IO: ActixStream + 'static> ServiceFactory<IO> for Acceptor {
125 type Response = TlsStream<IO>;
126 type Error = TlsError<Error, Infallible>;
127 type Config = ();
128 type Service = AcceptorService;
129 type InitError = ();
130 type Future = FutReady<Result<Self::Service, Self::InitError>>;
131
132 fn new_service(&self, _: ()) -> Self::Future {
133 let res = MAX_CONN_COUNTER.with(|conns| {
134 Ok(AcceptorService {
135 acceptor: self.acceptor.clone(),
136 conns: conns.clone(),
137 handshake_timeout: self.handshake_timeout,
138 })
139 });
140
141 ready(res)
142 }
143}
144
145pub struct AcceptorService {
147 acceptor: TlsAcceptor,
148 conns: Counter,
149 handshake_timeout: Duration,
150}
151
152impl<IO: ActixStream + 'static> Service<IO> for AcceptorService {
153 type Response = TlsStream<IO>;
154 type Error = TlsError<Error, Infallible>;
155 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
156
157 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
158 if self.conns.available(cx) {
159 Poll::Ready(Ok(()))
160 } else {
161 Poll::Pending
162 }
163 }
164
165 fn call(&self, io: IO) -> Self::Future {
166 let guard = self.conns.get();
167 let acceptor = self.acceptor.clone();
168
169 let dur = self.handshake_timeout;
170
171 Box::pin(async move {
172 match timeout(dur, acceptor.accept(io)).await {
173 Ok(Ok(io)) => {
174 drop(guard);
175 Ok(TlsStream(io))
176 }
177 Ok(Err(err)) => Err(TlsError::Tls(err)),
178 Err(_timeout) => Err(TlsError::Timeout),
179 }
180 })
181 }
182}