1use crate::{
2 common::tokio_stream::TokioListenerStream, ConnectionError, LocalAddress, Preview,
3 PreviewConfiguration, ResolvedTarget, RewindStream, Ssl, StreamUpgrade, TlsDriver,
4 TlsServerParameterProvider, UpgradableStream, DEFAULT_TLS_BACKLOG,
5};
6use futures::{stream::FuturesUnordered, StreamExt};
7use std::{
8 future::Future,
9 pin::Pin,
10 task::{ready, Poll},
11};
12use std::{net::SocketAddr, path::Path};
13use tokio::io::AsyncReadExt;
14
/// Connection type yielded by the acceptors in this module: a
/// [`crate::BaseStream`] wrapped in an [`UpgradableStream`] that uses the TLS
/// driver `D` (defaulting to [`Ssl`]).
type Connection<D = Ssl> = UpgradableStream<crate::BaseStream, D>;
16
/// Listens on a [`ResolvedTarget`] and produces a stream of accepted
/// connections via its `bind*` methods.
///
/// `Acceptor<false>` yields plain [`Connection`]s. `Acceptor<true>` reads a
/// fixed-size prefix of each connection first and yields it as a [`Preview`]
/// alongside the connection.
pub struct Acceptor<const PREVIEW: bool = false> {
    /// Address the listener is bound to.
    resolved_target: ResolvedTarget,
    /// TLS server parameters handed to every accepted `UpgradableStream`.
    tls_provider: Option<TlsServerParameterProvider>,
    /// When `true`, each connection's TLS handshake is completed before the
    /// connection is yielded; when `false`, any upgrade is left to the caller.
    should_upgrade: bool,
    /// Optional tuning (backlogs, close_notify handling, preview settings).
    options: StreamOptions<PREVIEW>,
}
23
/// Optional tuning for an [`Acceptor`]; the `Default` impl leaves every
/// option unset.
// NOTE(review): `PREVIEW` is not used by any field; it only ties the options
// to the matching `Acceptor<PREVIEW>` variant.
#[derive(Debug, Clone, Copy)]
struct StreamOptions<const PREVIEW: bool> {
    /// Forwarded to each accepted stream via `ignore_missing_close_notify()`.
    ignore_missing_tls_close_notify: bool,
    /// Preview settings; only read by the `(Preview, Connection)` stream.
    preview_configuration: Option<PreviewConfiguration>,
    /// Listen backlog passed straight to `listen_raw`.
    tcp_backlog: Option<u32>,
    /// Maximum number of in-flight TLS handshakes / preview reads; `None`
    /// makes the bind methods fall back to a default.
    tls_backlog: Option<u32>,
}
31
32impl<const PREVIEW: bool> Default for StreamOptions<PREVIEW> {
33 fn default() -> Self {
34 Self {
35 ignore_missing_tls_close_notify: false,
36 preview_configuration: None,
37 tcp_backlog: None,
38 tls_backlog: None,
39 }
40 }
41}
42
43impl Acceptor<false> {
44 pub fn new(target: ResolvedTarget) -> Self {
45 Self {
46 resolved_target: target,
47 tls_provider: None,
48 should_upgrade: false,
49 options: Default::default(),
50 }
51 }
52
53 pub fn new_tls(target: ResolvedTarget, provider: TlsServerParameterProvider) -> Self {
54 Self {
55 resolved_target: target,
56 tls_provider: Some(provider),
57 should_upgrade: true,
58 options: Default::default(),
59 }
60 }
61
62 pub fn new_starttls(target: ResolvedTarget, provider: TlsServerParameterProvider) -> Self {
63 Self {
64 resolved_target: target,
65 tls_provider: Some(provider),
66 should_upgrade: false,
67 options: Default::default(),
68 }
69 }
70
71 pub fn new_tcp(addr: SocketAddr) -> Self {
72 Self {
73 resolved_target: ResolvedTarget::SocketAddr(addr),
74 tls_provider: None,
75 should_upgrade: false,
76 options: Default::default(),
77 }
78 }
79
80 pub fn new_tcp_tls(addr: SocketAddr, provider: TlsServerParameterProvider) -> Self {
81 Self {
82 resolved_target: ResolvedTarget::SocketAddr(addr),
83 tls_provider: Some(provider),
84 should_upgrade: true,
85 options: Default::default(),
86 }
87 }
88
89 pub fn new_tcp_starttls(addr: SocketAddr, provider: TlsServerParameterProvider) -> Self {
90 Self {
91 resolved_target: ResolvedTarget::SocketAddr(addr),
92 tls_provider: Some(provider),
93 should_upgrade: false,
94 options: Default::default(),
95 }
96 }
97
98 pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
99 #[cfg(unix)]
100 {
101 Ok(Self {
102 resolved_target: ResolvedTarget::from(
103 std::os::unix::net::SocketAddr::from_pathname(path)?,
104 ),
105 tls_provider: None,
106 should_upgrade: false,
107 options: Default::default(),
108 })
109 }
110 #[cfg(not(unix))]
111 {
112 Err(std::io::Error::new(
113 std::io::ErrorKind::Unsupported,
114 "Unix domain sockets are not supported on this platform",
115 ))
116 }
117 }
118
119 pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
120 #[cfg(any(target_os = "linux", target_os = "android"))]
121 {
122 use std::os::linux::net::SocketAddrExt;
123 Ok(Self {
124 resolved_target: ResolvedTarget::from(
125 std::os::unix::net::SocketAddr::from_abstract_name(domain)?,
126 ),
127 tls_provider: None,
128 should_upgrade: false,
129 options: Default::default(),
130 })
131 }
132 #[cfg(not(any(target_os = "linux", target_os = "android")))]
133 {
134 Err(std::io::Error::new(
135 std::io::ErrorKind::Unsupported,
136 "Unix domain sockets are not supported on this platform",
137 ))
138 }
139 }
140
141 pub async fn bind(
142 self,
143 ) -> Result<
144 impl ::futures::Stream<Item = Result<Connection, ConnectionError>> + LocalAddress,
145 ConnectionError,
146 > {
147 let stream = self
148 .resolved_target
149 .listen_raw(self.options.tcp_backlog)
150 .await?;
151 Ok(AcceptedStream::<Connection<Ssl>> {
152 stream,
153 should_upgrade: self.should_upgrade,
154 ignore_missing_tls_close_notify: self.options.ignore_missing_tls_close_notify,
155 tls_provider: self.tls_provider,
156 tls_backlog: TlsAcceptBacklog::new(
157 self.options.tls_backlog.unwrap_or(DEFAULT_TLS_BACKLOG) as _,
158 ),
159 preview_configuration: None,
160 _phantom: None,
161 })
162 }
163
164 #[allow(private_bounds)]
165 pub async fn bind_explicit<D: TlsDriver>(
166 self,
167 ) -> Result<
168 impl ::futures::Stream<Item = Result<Connection<D>, ConnectionError>> + LocalAddress,
169 ConnectionError,
170 > {
171 let stream = self
172 .resolved_target
173 .listen_raw(self.options.tcp_backlog)
174 .await?;
175 Ok(AcceptedStream::<Connection<D>, D> {
176 stream,
177 ignore_missing_tls_close_notify: self.options.ignore_missing_tls_close_notify,
178 should_upgrade: self.should_upgrade,
179 tls_provider: self.tls_provider,
180 tls_backlog: TlsAcceptBacklog::new(
181 self.options.tls_backlog.unwrap_or(DEFAULT_TLS_BACKLOG) as _,
182 ),
183 preview_configuration: None,
184 _phantom: None,
185 })
186 }
187
188 pub async fn accept_one(self) -> Result<Connection, ConnectionError> {
190 let Some(conn) = self.bind().await?.next().await else {
191 return Err(ConnectionError::Io(std::io::Error::new(
192 std::io::ErrorKind::Interrupted,
193 "No connection received",
194 )));
195 };
196 conn
197 }
198}
199
200impl Acceptor<true> {
201 pub fn new_tcp_tls_previewing(
204 addr: SocketAddr,
205 preview_configuration: PreviewConfiguration,
206 provider: TlsServerParameterProvider,
207 ) -> Self {
208 Self {
209 resolved_target: ResolvedTarget::SocketAddr(addr),
210 tls_provider: Some(provider),
211 should_upgrade: false,
212 options: StreamOptions {
213 preview_configuration: Some(preview_configuration),
214 ..Default::default()
215 },
216 }
217 }
218
219 pub fn new_tls_previewing(
222 addr: ResolvedTarget,
223 preview_configuration: PreviewConfiguration,
224 provider: TlsServerParameterProvider,
225 ) -> Self {
226 Self {
227 resolved_target: addr,
228 tls_provider: Some(provider),
229 should_upgrade: false,
230 options: StreamOptions {
231 preview_configuration: Some(preview_configuration),
232 ..Default::default()
233 },
234 }
235 }
236
237 pub fn new_previewing(
240 addr: ResolvedTarget,
241 preview_configuration: PreviewConfiguration,
242 ) -> Self {
243 Self {
244 resolved_target: addr,
245 tls_provider: None,
246 should_upgrade: false,
247 options: StreamOptions {
248 preview_configuration: Some(preview_configuration),
249 ..Default::default()
250 },
251 }
252 }
253
254 pub async fn bind(
255 self,
256 ) -> Result<
257 impl ::futures::Stream<Item = Result<(Preview, Connection), ConnectionError>> + LocalAddress,
258 ConnectionError,
259 > {
260 let stream = self
261 .resolved_target
262 .listen_raw(self.options.tcp_backlog)
263 .await?;
264 Ok(AcceptedStream::<(Preview, Connection<Ssl>)> {
265 stream,
266 should_upgrade: self.should_upgrade,
267 ignore_missing_tls_close_notify: self.options.ignore_missing_tls_close_notify,
268 tls_provider: self.tls_provider,
269 tls_backlog: TlsAcceptBacklog::new(self.options.tls_backlog.unwrap_or(128) as _),
270 preview_configuration: self.options.preview_configuration,
271 _phantom: None,
272 })
273 }
274
275 #[allow(private_bounds)]
276 pub async fn bind_explicit<D: TlsDriver>(
277 self,
278 ) -> Result<
279 impl ::futures::Stream<Item = Result<(Preview, Connection<D>), ConnectionError>> + LocalAddress,
280 ConnectionError,
281 > {
282 let stream = self
283 .resolved_target
284 .listen_raw(self.options.tcp_backlog)
285 .await?;
286 Ok(AcceptedStream::<(Preview, Connection<D>), D> {
287 stream,
288 should_upgrade: self.should_upgrade,
289 ignore_missing_tls_close_notify: self.options.ignore_missing_tls_close_notify,
290 tls_provider: self.tls_provider,
291 tls_backlog: TlsAcceptBacklog::new(
292 self.options.tls_backlog.unwrap_or(DEFAULT_TLS_BACKLOG) as _,
293 ),
294 preview_configuration: self.options.preview_configuration,
295 _phantom: None,
296 })
297 }
298
299 pub async fn accept_one(self) -> Result<(Preview, Connection), ConnectionError> {
301 let Some(conn) = self.bind().await?.next().await else {
302 return Err(ConnectionError::Io(std::io::Error::new(
303 std::io::ErrorKind::Interrupted,
304 "No connection received",
305 )));
306 };
307 conn
308 }
309}
310
/// Stream returned by the `bind*` methods. It accepts raw connections from
/// `stream` and turns them into items of type `S`.
///
/// Work that must finish before a connection is yielded (a TLS handshake, or
/// reading the preview bytes) runs concurrently in `tls_backlog`, bounded by
/// its capacity.
struct AcceptedStream<S, D: TlsDriver = Ssl> {
    /// The bound listener.
    stream: TokioListenerStream,
    /// Complete the TLS handshake before yielding (ignored by the preview stream).
    should_upgrade: bool,
    /// Applied to every stream via `ignore_missing_close_notify()`.
    ignore_missing_tls_close_notify: bool,
    /// Cloned into every accepted stream.
    tls_provider: Option<TlsServerParameterProvider>,
    /// In-flight handshake / preview futures.
    tls_backlog: TlsAcceptBacklog<S>,
    /// Only used by the preview stream, whose `poll_next` unwraps it.
    preview_configuration: Option<PreviewConfiguration>,
    /// Ties the TLS driver type `D` to this stream; never read.
    // NOTE(review): `PhantomData<fn() -> D>` would express this without
    // implying `D: 'static` or making `Send` depend on `D: Sync`.
    _phantom: Option<&'static D>,
}
321
322impl<S, D: TlsDriver> LocalAddress for AcceptedStream<S, D> {
323 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
324 self.stream.local_address()
325 }
326}
327
/// Stream of plain [`Connection`]s.
///
/// Without `should_upgrade`, each accepted socket is wrapped and yielded
/// immediately. With `should_upgrade`, a TLS handshake is started for each
/// accepted socket, and the connection is yielded once its handshake
/// finishes. Up to the backlog's capacity of handshakes run concurrently and
/// may complete in any order.
impl<D: TlsDriver> futures::Stream for AcceptedStream<Connection<D>, D> {
    type Item = Result<Connection<D>, ConnectionError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Copy the flag out so the `move` closure below does not borrow `self`.
        let ignore_missing_tls_close_notify = self.ignore_missing_tls_close_notify;
        // Wraps an accepted socket as a server-side `UpgradableStream` and
        // applies the close_notify option.
        let make_stream = move |tls_provider: Option<TlsServerParameterProvider>, stream| {
            let mut stream = UpgradableStream::<_, D>::new_server(stream, tls_provider);
            if ignore_missing_tls_close_notify {
                stream.ignore_missing_close_notify();
            }
            stream
        };

        if !self.should_upgrade {
            // No handshake: forward the listener's result directly. Accept
            // errors are converted into `ConnectionError` by the `?` inside the
            // closure and yielded as `Some(Err(_))`.
            return self.as_mut().stream.poll_next_unpin(cx).map(|c| {
                c.map(|c| Ok(c.map(|(c, _t)| make_stream(self.tls_provider.clone(), c))?))
            });
        }

        // Admit newly accepted sockets into the handshake backlog while it has
        // room. Polling the listener until it returns `Pending` also registers
        // this task's waker for the next connection. While the backlog is
        // full, the listener is not polled, so no new clients are accepted
        // until a handshake completes.
        // NOTE(review): no timeout is applied in this function, so slow
        // handshakes can keep the backlog full.
        while !self.tls_backlog.is_full() {
            let Poll::Ready(r) = self.stream.poll_next_unpin(cx) else {
                // Nothing to accept right now: return `Pending` only if there
                // is also no handshake to drive.
                if self.tls_backlog.is_empty() {
                    return Poll::Pending;
                }
                break;
            };

            // An accept error is yielded immediately (via `?`); handshakes
            // already in the backlog stay queued for later polls.
            let Some((stream, _t)) = r.transpose()? else {
                // Listener exhausted: end the stream once the backlog drains.
                // NOTE(review): with a non-empty backlog, the listener is polled
                // again on the next call even though it already returned
                // `None`; confirm `TokioListenerStream` tolerates that.
                if self.tls_backlog.is_empty() {
                    return Poll::Ready(None);
                }
                break;
            };

            let tls_provider = self.tls_provider.clone();
            self.tls_backlog.push(async move {
                let stream = make_stream(tls_provider, stream);
                let stream = stream.secure_upgrade().await?;
                Ok(stream)
            })
        }

        // Every `break` above leaves at least one future queued, as does a
        // full backlog with non-zero capacity.
        debug_assert!(!self.tls_backlog.is_empty());
        // A failed handshake is yielded as `Some(Err(_))` and does not end the
        // stream.
        let r = ready!(Pin::new(&mut self.tls_backlog).poll_next(cx))?;
        Poll::Ready(Some(Ok(r)))
    }
}
382
/// Stream of `(Preview, Connection)` pairs.
///
/// For every accepted socket, exactly `max_preview_bytes` bytes are read into
/// a [`Preview`]. The bytes are then rewound so the returned connection yields
/// them again. No TLS handshake is performed here (`should_upgrade` is not
/// consulted); preview reads run concurrently in the backlog and may complete
/// in any order.
impl<D: TlsDriver> futures::Stream for AcceptedStream<(Preview, Connection<D>), D> {
    type Item = Result<(Preview, Connection<D>), ConnectionError>;
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Admit newly accepted sockets while the backlog has room. Polling the
        // listener until `Pending` registers the waker for the next connection.
        while !self.tls_backlog.is_full() {
            let Poll::Ready(r) = self.stream.poll_next_unpin(cx) else {
                // Nothing to accept: return `Pending` only if there is also no
                // preview read to drive.
                if self.tls_backlog.is_empty() {
                    return Poll::Pending;
                }
                break;
            };

            // Accept errors are yielded immediately; queued reads are kept.
            let Some((mut stream, _t)) = r.transpose()? else {
                // Listener exhausted: end the stream once the backlog drains.
                // NOTE(review): with a non-empty backlog, the listener is polled
                // again on the next call; confirm that is tolerated.
                if self.tls_backlog.is_empty() {
                    return Poll::Ready(None);
                }
                break;
            };

            let tls_provider = self.tls_provider.clone();
            // Panics if unset; the `Acceptor<true>` constructors always set it.
            let preview_configuration = self.preview_configuration.unwrap();
            let ignore_missing_tls_close_notify = self.ignore_missing_tls_close_notify;
            self.tls_backlog.push(async move {
                let mut buf = smallvec::SmallVec::with_capacity(
                    preview_configuration.max_preview_bytes.get(),
                );
                buf.resize(preview_configuration.max_preview_bytes.get(), 0);
                // `read_exact` errors if the client closes before sending the
                // full preview. A client that sends fewer bytes but stays open
                // holds a backlog slot, since no timeout is applied here.
                stream.read_exact(&mut buf).await?;
                // Push the previewed bytes back so later reads see them again.
                let mut stream = RewindStream::new(stream);
                stream.rewind(&buf);
                let preview = Preview::new(buf);
                let mut stream = UpgradableStream::<_, D>::new_server_preview(stream, tls_provider);
                if ignore_missing_tls_close_notify {
                    stream.ignore_missing_close_notify();
                }

                Ok((preview, stream))
            })
        }

        // Every `break` above leaves at least one future queued, as does a
        // full backlog with non-zero capacity.
        debug_assert!(!self.tls_backlog.is_empty());
        // A failed preview read is yielded as `Some(Err(_))` without ending
        // the stream.
        let r = ready!(Pin::new(&mut self.tls_backlog).poll_next(cx))?;
        Poll::Ready(Some(Ok(r)))
    }
}
432
/// Bounded set of in-flight accept futures (TLS handshakes or preview reads)
/// that are polled together and complete in any order.
struct TlsAcceptBacklog<C> {
    /// Maximum number of futures admitted at once (checked by `is_full`).
    capacity: usize,
    // Futures are boxed so that different `async` blocks can share one
    // collection; `Send + 'static` is required of every pushed future.
    #[allow(clippy::type_complexity)]
    futures: FuturesUnordered<
        Pin<Box<dyn Future<Output = Result<C, ConnectionError>> + Send + 'static>>,
    >,
}
440
441impl<C> TlsAcceptBacklog<C> {
442 fn new(capacity: usize) -> Self {
443 Self {
444 capacity,
445 futures: FuturesUnordered::new(),
446 }
447 }
448
449 fn is_full(&self) -> bool {
450 self.futures.len() >= self.capacity
451 }
452
453 fn is_empty(&self) -> bool {
454 self.futures.len() == 0
455 }
456
457 fn poll_next(
458 mut self: std::pin::Pin<&mut Self>,
459 cx: &mut std::task::Context<'_>,
460 ) -> Poll<Result<C, ConnectionError>> {
461 debug_assert!(!self.is_empty());
462 self.futures.poll_next_unpin(cx).map(|r| r.unwrap())
463 }
464
465 fn push(&mut self, future: impl Future<Output = Result<C, ConnectionError>> + Send + 'static) {
466 self.futures.push(Box::pin(future));
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Connector, OpensslDriver, RustlsDriver, Target, TlsParameters, TlsServerParameters,
    };
    use std::net::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Accepts one plaintext and one TLS client on a previewing acceptor, and
    /// checks both the preview bytes and the full payload.
    ///
    /// Each client task's result is awaited, so a failed connect or write is
    /// reported instead of being silently discarded with the `JoinHandle`.
    async fn test_acceptor_new_tcp_previewing<D: TlsDriver>() -> Result<(), ConnectionError> {
        let acceptor = Acceptor::new_tcp_tls_previewing(
            SocketAddr::from((Ipv4Addr::LOCALHOST, 0)),
            PreviewConfiguration::default(),
            TlsServerParameterProvider::new(TlsServerParameters::new_with_certificate(
                crate::test_keys::SERVER_KEY.clone_key(),
            )),
        );

        let mut conns = acceptor.bind_explicit::<D>().await?;

        // Plaintext client: the preview is simply the start of the payload.
        let addr = conns.local_address()?;
        let plain_client = tokio::task::spawn(async move {
            let mut conn = Connector::new_resolved(addr).connect().await?;
            conn.write_all(b"HELLO WORLD").await
        });

        let (preview, mut conn) = conns.next().await.unwrap()?;
        assert_eq!(preview.len(), 8);
        assert_eq!(preview, b"HELLO WO");
        let mut string = String::new();
        conn.read_to_string(&mut string).await?;
        assert_eq!(string, "HELLO WORLD");
        plain_client.await.expect("plaintext client task panicked")?;

        // TLS client: the first preview is the start of a TLS handshake record;
        // after upgrading, a second preview shows the decrypted payload.
        let addr = conns.local_address()?;
        let tls_client = tokio::task::spawn(async move {
            let target = Target::new_resolved_tls(addr, TlsParameters::insecure());
            let mut conn = Connector::new(target)?.connect().await?;
            conn.write_all(b"HELLO WORLD").await
        });

        let (preview, conn) = conns.next().await.unwrap()?;
        assert_eq!(preview.len(), 8);
        assert!(matches!(preview.as_ref(), [0x16, 3, 1, ..]));
        let (preview, mut conn) = conn
            .secure_upgrade_preview(PreviewConfiguration::default())
            .await?;
        assert_eq!(preview.len(), 8);
        assert_eq!(preview, b"HELLO WO");

        let mut string = String::new();
        conn.read_to_string(&mut string).await?;
        assert_eq!(string, "HELLO WORLD");
        tls_client.await.expect("TLS client task panicked")?;

        Ok(())
    }

    #[tokio::test]
    async fn test_acceptor_new_tcp_previewing_openssl() -> Result<(), ConnectionError> {
        test_acceptor_new_tcp_previewing::<OpensslDriver>().await
    }

    #[tokio::test]
    async fn test_acceptor_new_tcp_previewing_rustls() -> Result<(), ConnectionError> {
        test_acceptor_new_tcp_previewing::<RustlsDriver>().await
    }
}