//! `actix_tls::accept::rustls_0_22`: TLS connection acceptor service built on `tokio-rustls`
//! (the `tokio_rustls_025` dependency).
//!
//! See [`Acceptor`] for the main service factory.

use std::{
    convert::Infallible,
    future::Future,
    io::{self, IoSlice},
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::Duration,
};

use actix_rt::{
    net::{ActixStream, Ready},
    time::{sleep, Sleep},
};
use actix_service::{Service, ServiceFactory};
use actix_utils::{
    counter::{Counter, CounterGuard},
    future::{ready, Ready as FutReady},
};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::{Accept, TlsAcceptor};
use tokio_rustls_025 as tokio_rustls;

use super::{TlsError, DEFAULT_TLS_HANDSHAKE_TIMEOUT, MAX_CONN_COUNTER};
31pub mod reexports {
32 pub use tokio_rustls_025::rustls::ServerConfig;
35}
36
37pub struct TlsStream<IO>(tokio_rustls::server::TlsStream<IO>);
39
40impl_more::impl_from!(<IO> in tokio_rustls::server::TlsStream<IO> => TlsStream<IO>);
41impl_more::impl_deref_and_mut!(<IO> in TlsStream<IO> => tokio_rustls::server::TlsStream<IO>);
42
43impl<IO: ActixStream> AsyncRead for TlsStream<IO> {
44 fn poll_read(
45 self: Pin<&mut Self>,
46 cx: &mut Context<'_>,
47 buf: &mut ReadBuf<'_>,
48 ) -> Poll<io::Result<()>> {
49 Pin::new(&mut **self.get_mut()).poll_read(cx, buf)
50 }
51}
52
53impl<IO: ActixStream> AsyncWrite for TlsStream<IO> {
54 fn poll_write(
55 self: Pin<&mut Self>,
56 cx: &mut Context<'_>,
57 buf: &[u8],
58 ) -> Poll<io::Result<usize>> {
59 Pin::new(&mut **self.get_mut()).poll_write(cx, buf)
60 }
61
62 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
63 Pin::new(&mut **self.get_mut()).poll_flush(cx)
64 }
65
66 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
67 Pin::new(&mut **self.get_mut()).poll_shutdown(cx)
68 }
69
70 fn poll_write_vectored(
71 self: Pin<&mut Self>,
72 cx: &mut Context<'_>,
73 bufs: &[IoSlice<'_>],
74 ) -> Poll<io::Result<usize>> {
75 Pin::new(&mut **self.get_mut()).poll_write_vectored(cx, bufs)
76 }
77
78 fn is_write_vectored(&self) -> bool {
79 (**self).is_write_vectored()
80 }
81}
82
83impl<IO: ActixStream> ActixStream for TlsStream<IO> {
84 fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
85 IO::poll_read_ready((**self).get_ref().0, cx)
86 }
87
88 fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
89 IO::poll_write_ready((**self).get_ref().0, cx)
90 }
91}
92
93pub struct Acceptor {
95 config: Arc<reexports::ServerConfig>,
96 handshake_timeout: Duration,
97}
98
99impl Acceptor {
100 pub fn new(config: reexports::ServerConfig) -> Self {
102 Acceptor {
103 config: Arc::new(config),
104 handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
105 }
106 }
107
108 pub fn set_handshake_timeout(&mut self, handshake_timeout: Duration) -> &mut Self {
112 self.handshake_timeout = handshake_timeout;
113 self
114 }
115}
116
117impl Clone for Acceptor {
118 fn clone(&self) -> Self {
119 Self {
120 config: self.config.clone(),
121 handshake_timeout: self.handshake_timeout,
122 }
123 }
124}
125
126impl<IO: ActixStream> ServiceFactory<IO> for Acceptor {
127 type Response = TlsStream<IO>;
128 type Error = TlsError<io::Error, Infallible>;
129 type Config = ();
130 type Service = AcceptorService;
131 type InitError = ();
132 type Future = FutReady<Result<Self::Service, Self::InitError>>;
133
134 fn new_service(&self, _: ()) -> Self::Future {
135 let res = MAX_CONN_COUNTER.with(|conns| {
136 Ok(AcceptorService {
137 acceptor: self.config.clone().into(),
138 conns: conns.clone(),
139 handshake_timeout: self.handshake_timeout,
140 })
141 });
142
143 ready(res)
144 }
145}
146
147pub struct AcceptorService {
149 acceptor: TlsAcceptor,
150 conns: Counter,
151 handshake_timeout: Duration,
152}
153
154impl<IO: ActixStream> Service<IO> for AcceptorService {
155 type Response = TlsStream<IO>;
156 type Error = TlsError<io::Error, Infallible>;
157 type Future = AcceptFut<IO>;
158
159 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
160 if self.conns.available(cx) {
161 Poll::Ready(Ok(()))
162 } else {
163 Poll::Pending
164 }
165 }
166
167 fn call(&self, req: IO) -> Self::Future {
168 AcceptFut {
169 fut: self.acceptor.accept(req),
170 timeout: sleep(self.handshake_timeout),
171 _guard: self.conns.get(),
172 }
173 }
174}
175
176pin_project! {
177 #[doc(hidden)]
179 pub struct AcceptFut<IO: ActixStream> {
180 fut: Accept<IO>,
181 #[pin]
182 timeout: Sleep,
183 _guard: CounterGuard,
184 }
185}
186
187impl<IO: ActixStream> Future for AcceptFut<IO> {
188 type Output = Result<TlsStream<IO>, TlsError<io::Error, Infallible>>;
189
190 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
191 let mut this = self.project();
192 match Pin::new(&mut this.fut).poll(cx) {
193 Poll::Ready(Ok(stream)) => Poll::Ready(Ok(TlsStream(stream))),
194 Poll::Ready(Err(err)) => Poll::Ready(Err(TlsError::Tls(err))),
195 Poll::Pending => this.timeout.poll(cx).map(|_| Err(TlsError::Timeout)),
196 }
197 }
198}