gel_stream/server/
acceptor.rs

1use crate::{
2    common::tokio_stream::TokioListenerStream, ConnectionError, LocalAddress, Preview,
3    PreviewConfiguration, ResolvedTarget, RewindStream, Ssl, StreamUpgrade, TlsDriver,
4    TlsServerParameterProvider, UpgradableStream, DEFAULT_TLS_BACKLOG,
5};
6use futures::{stream::FuturesUnordered, StreamExt};
7use std::{
8    future::Future,
9    pin::Pin,
10    task::{ready, Poll},
11};
12use std::{net::SocketAddr, path::Path};
13use tokio::io::AsyncReadExt;
14
/// A server-side connection: an [`UpgradableStream`] layered over
/// [`crate::BaseStream`], parameterised by the TLS driver `D` (defaults to
/// [`Ssl`]).
type Connection<D = Ssl> = UpgradableStream<crate::BaseStream, D>;
16
/// Listens on a [`ResolvedTarget`] and hands out accepted connections as a
/// stream.
///
/// The `PREVIEW` const parameter selects which set of inherent methods is
/// available: `Acceptor<false>` yields connections, `Acceptor<true>` yields
/// `(Preview, Connection)` pairs.
pub struct Acceptor<const PREVIEW: bool = false> {
    /// The address that is passed to `listen_raw` when binding.
    resolved_target: ResolvedTarget,
    /// TLS parameters attached to every accepted connection, if any.
    tls_provider: Option<TlsServerParameterProvider>,
    /// When `true`, the non-previewing stream runs `secure_upgrade()` on each
    /// connection before yielding it.
    should_upgrade: bool,
    /// Remaining tunables (backlogs, preview configuration, close-notify handling).
    options: StreamOptions<PREVIEW>,
}
23
/// Tunables carried by an [`Acceptor`] and copied into the stream it binds.
#[derive(Debug, Clone, Copy)]
struct StreamOptions<const PREVIEW: bool> {
    /// When `true`, `ignore_missing_close_notify()` is called on every
    /// accepted connection.
    ignore_missing_tls_close_notify: bool,
    /// Preview settings; only populated by the previewing constructors.
    preview_configuration: Option<PreviewConfiguration>,
    /// Backlog value forwarded to `listen_raw`; `None` is passed through as-is.
    tcp_backlog: Option<u32>,
    /// Capacity of the second-level backlog of in-flight per-connection
    /// futures; `None` falls back to a default chosen at bind time.
    tls_backlog: Option<u32>,
}
31
32impl<const PREVIEW: bool> Default for StreamOptions<PREVIEW> {
33    fn default() -> Self {
34        Self {
35            ignore_missing_tls_close_notify: false,
36            preview_configuration: None,
37            tcp_backlog: None,
38            tls_backlog: None,
39        }
40    }
41}
42
43impl Acceptor<false> {
44    pub fn new(target: ResolvedTarget) -> Self {
45        Self {
46            resolved_target: target,
47            tls_provider: None,
48            should_upgrade: false,
49            options: Default::default(),
50        }
51    }
52
53    pub fn new_tls(target: ResolvedTarget, provider: TlsServerParameterProvider) -> Self {
54        Self {
55            resolved_target: target,
56            tls_provider: Some(provider),
57            should_upgrade: true,
58            options: Default::default(),
59        }
60    }
61
62    pub fn new_starttls(target: ResolvedTarget, provider: TlsServerParameterProvider) -> Self {
63        Self {
64            resolved_target: target,
65            tls_provider: Some(provider),
66            should_upgrade: false,
67            options: Default::default(),
68        }
69    }
70
71    pub fn new_tcp(addr: SocketAddr) -> Self {
72        Self {
73            resolved_target: ResolvedTarget::SocketAddr(addr),
74            tls_provider: None,
75            should_upgrade: false,
76            options: Default::default(),
77        }
78    }
79
80    pub fn new_tcp_tls(addr: SocketAddr, provider: TlsServerParameterProvider) -> Self {
81        Self {
82            resolved_target: ResolvedTarget::SocketAddr(addr),
83            tls_provider: Some(provider),
84            should_upgrade: true,
85            options: Default::default(),
86        }
87    }
88
89    pub fn new_tcp_starttls(addr: SocketAddr, provider: TlsServerParameterProvider) -> Self {
90        Self {
91            resolved_target: ResolvedTarget::SocketAddr(addr),
92            tls_provider: Some(provider),
93            should_upgrade: false,
94            options: Default::default(),
95        }
96    }
97
98    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
99        #[cfg(unix)]
100        {
101            Ok(Self {
102                resolved_target: ResolvedTarget::from(
103                    std::os::unix::net::SocketAddr::from_pathname(path)?,
104                ),
105                tls_provider: None,
106                should_upgrade: false,
107                options: Default::default(),
108            })
109        }
110        #[cfg(not(unix))]
111        {
112            Err(std::io::Error::new(
113                std::io::ErrorKind::Unsupported,
114                "Unix domain sockets are not supported on this platform",
115            ))
116        }
117    }
118
119    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
120        #[cfg(any(target_os = "linux", target_os = "android"))]
121        {
122            use std::os::linux::net::SocketAddrExt;
123            Ok(Self {
124                resolved_target: ResolvedTarget::from(
125                    std::os::unix::net::SocketAddr::from_abstract_name(domain)?,
126                ),
127                tls_provider: None,
128                should_upgrade: false,
129                options: Default::default(),
130            })
131        }
132        #[cfg(not(any(target_os = "linux", target_os = "android")))]
133        {
134            Err(std::io::Error::new(
135                std::io::ErrorKind::Unsupported,
136                "Unix domain sockets are not supported on this platform",
137            ))
138        }
139    }
140
141    pub async fn bind(
142        self,
143    ) -> Result<
144        impl ::futures::Stream<Item = Result<Connection, ConnectionError>> + LocalAddress,
145        ConnectionError,
146    > {
147        let stream = self
148            .resolved_target
149            .listen_raw(self.options.tcp_backlog)
150            .await?;
151        Ok(AcceptedStream::<Connection<Ssl>> {
152            stream,
153            should_upgrade: self.should_upgrade,
154            ignore_missing_tls_close_notify: self.options.ignore_missing_tls_close_notify,
155            tls_provider: self.tls_provider,
156            tls_backlog: TlsAcceptBacklog::new(
157                self.options.tls_backlog.unwrap_or(DEFAULT_TLS_BACKLOG) as _,
158            ),
159            preview_configuration: None,
160            _phantom: None,
161        })
162    }
163
164    #[allow(private_bounds)]
165    pub async fn bind_explicit<D: TlsDriver>(
166        self,
167    ) -> Result<
168        impl ::futures::Stream<Item = Result<Connection<D>, ConnectionError>> + LocalAddress,
169        ConnectionError,
170    > {
171        let stream = self
172            .resolved_target
173            .listen_raw(self.options.tcp_backlog)
174            .await?;
175        Ok(AcceptedStream::<Connection<D>, D> {
176            stream,
177            ignore_missing_tls_close_notify: self.options.ignore_missing_tls_close_notify,
178            should_upgrade: self.should_upgrade,
179            tls_provider: self.tls_provider,
180            tls_backlog: TlsAcceptBacklog::new(
181                self.options.tls_backlog.unwrap_or(DEFAULT_TLS_BACKLOG) as _,
182            ),
183            preview_configuration: None,
184            _phantom: None,
185        })
186    }
187
188    /// Listen, and then accept one and only one connection from the listener.
189    pub async fn accept_one(self) -> Result<Connection, ConnectionError> {
190        let Some(conn) = self.bind().await?.next().await else {
191            return Err(ConnectionError::Io(std::io::Error::new(
192                std::io::ErrorKind::Interrupted,
193                "No connection received",
194            )));
195        };
196        conn
197    }
198}
199
200impl Acceptor<true> {
201    /// Create a new TCP/TLS acceptor that will preview the first
202    /// [`PreviewConfiguration::max_preview_bytes`] bytes of the connection.
203    pub fn new_tcp_tls_previewing(
204        addr: SocketAddr,
205        preview_configuration: PreviewConfiguration,
206        provider: TlsServerParameterProvider,
207    ) -> Self {
208        Self {
209            resolved_target: ResolvedTarget::SocketAddr(addr),
210            tls_provider: Some(provider),
211            should_upgrade: false,
212            options: StreamOptions {
213                preview_configuration: Some(preview_configuration),
214                ..Default::default()
215            },
216        }
217    }
218
219    /// Create a new acceptor that will preview the first
220    /// [`PreviewConfiguration::max_preview_bytes`] bytes of the connection.
221    pub fn new_tls_previewing(
222        addr: ResolvedTarget,
223        preview_configuration: PreviewConfiguration,
224        provider: TlsServerParameterProvider,
225    ) -> Self {
226        Self {
227            resolved_target: addr,
228            tls_provider: Some(provider),
229            should_upgrade: false,
230            options: StreamOptions {
231                preview_configuration: Some(preview_configuration),
232                ..Default::default()
233            },
234        }
235    }
236
237    /// Create a new acceptor that will preview the first
238    /// [`PreviewConfiguration::max_preview_bytes`] bytes of the connection.
239    pub fn new_previewing(
240        addr: ResolvedTarget,
241        preview_configuration: PreviewConfiguration,
242    ) -> Self {
243        Self {
244            resolved_target: addr,
245            tls_provider: None,
246            should_upgrade: false,
247            options: StreamOptions {
248                preview_configuration: Some(preview_configuration),
249                ..Default::default()
250            },
251        }
252    }
253
254    pub async fn bind(
255        self,
256    ) -> Result<
257        impl ::futures::Stream<Item = Result<(Preview, Connection), ConnectionError>> + LocalAddress,
258        ConnectionError,
259    > {
260        let stream = self
261            .resolved_target
262            .listen_raw(self.options.tcp_backlog)
263            .await?;
264        Ok(AcceptedStream::<(Preview, Connection<Ssl>)> {
265            stream,
266            should_upgrade: self.should_upgrade,
267            ignore_missing_tls_close_notify: self.options.ignore_missing_tls_close_notify,
268            tls_provider: self.tls_provider,
269            tls_backlog: TlsAcceptBacklog::new(self.options.tls_backlog.unwrap_or(128) as _),
270            preview_configuration: self.options.preview_configuration,
271            _phantom: None,
272        })
273    }
274
275    #[allow(private_bounds)]
276    pub async fn bind_explicit<D: TlsDriver>(
277        self,
278    ) -> Result<
279        impl ::futures::Stream<Item = Result<(Preview, Connection<D>), ConnectionError>> + LocalAddress,
280        ConnectionError,
281    > {
282        let stream = self
283            .resolved_target
284            .listen_raw(self.options.tcp_backlog)
285            .await?;
286        Ok(AcceptedStream::<(Preview, Connection<D>), D> {
287            stream,
288            should_upgrade: self.should_upgrade,
289            ignore_missing_tls_close_notify: self.options.ignore_missing_tls_close_notify,
290            tls_provider: self.tls_provider,
291            tls_backlog: TlsAcceptBacklog::new(
292                self.options.tls_backlog.unwrap_or(DEFAULT_TLS_BACKLOG) as _,
293            ),
294            preview_configuration: self.options.preview_configuration,
295            _phantom: None,
296        })
297    }
298
299    /// Listen, and then accept one and only one connection from the listener.
300    pub async fn accept_one(self) -> Result<(Preview, Connection), ConnectionError> {
301        let Some(conn) = self.bind().await?.next().await else {
302            return Err(ConnectionError::Io(std::io::Error::new(
303                std::io::ErrorKind::Interrupted,
304                "No connection received",
305            )));
306        };
307        conn
308    }
309}
310
/// The stream returned by the acceptor's `bind*` methods.
///
/// `S` is the item type produced on success (a connection, or a
/// `(Preview, Connection)` pair) and `D` is the TLS driver.
struct AcceptedStream<S, D: TlsDriver = Ssl> {
    /// The underlying listener, polled for newly accepted raw connections.
    stream: TokioListenerStream,
    /// Read by the non-previewing stream only: when `true`, connections go
    /// through `secure_upgrade()` in the backlog before being yielded.
    should_upgrade: bool,
    /// When `true`, `ignore_missing_close_notify()` is called on each connection.
    ignore_missing_tls_close_notify: bool,
    /// Cloned into every accepted connection.
    tls_provider: Option<TlsServerParameterProvider>,
    /// Per-connection futures that are still in flight.
    tls_backlog: TlsAcceptBacklog<S>,
    /// Preview settings; the previewing stream unwraps this for every
    /// accepted connection.
    preview_configuration: Option<PreviewConfiguration>,
    // Avoid using PhantomData because it fails to implement certain auto-traits
    _phantom: Option<&'static D>,
}
321
322impl<S, D: TlsDriver> LocalAddress for AcceptedStream<S, D> {
323    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
324        self.stream.local_address()
325    }
326}
327
328impl<D: TlsDriver> futures::Stream for AcceptedStream<Connection<D>, D> {
329    type Item = Result<Connection<D>, ConnectionError>;
330
331    fn poll_next(
332        mut self: std::pin::Pin<&mut Self>,
333        cx: &mut std::task::Context<'_>,
334    ) -> Poll<Option<Self::Item>> {
335        let ignore_missing_tls_close_notify = self.ignore_missing_tls_close_notify;
336        let make_stream = move |tls_provider: Option<TlsServerParameterProvider>, stream| {
337            let mut stream = UpgradableStream::<_, D>::new_server(stream, tls_provider);
338            if ignore_missing_tls_close_notify {
339                stream.ignore_missing_close_notify();
340            }
341            stream
342        };
343
344        // If we're not upgrading, we can just return the stream as is and skip
345        // the second-level backlog.
346        if !self.should_upgrade {
347            return self.as_mut().stream.poll_next_unpin(cx).map(|c| {
348                c.map(|c| Ok(c.map(|(c, _t)| make_stream(self.tls_provider.clone(), c))?))
349            });
350        }
351
352        // Fill the backlog to capacity as log as we have connections to accept.
353        while !self.tls_backlog.is_full() {
354            let Poll::Ready(r) = self.stream.poll_next_unpin(cx) else {
355                if self.tls_backlog.is_empty() {
356                    return Poll::Pending;
357                }
358                break;
359            };
360
361            let Some((stream, _t)) = r.transpose()? else {
362                if self.tls_backlog.is_empty() {
363                    return Poll::Ready(None);
364                }
365                break;
366            };
367
368            let tls_provider = self.tls_provider.clone();
369            self.tls_backlog.push(async move {
370                let stream = make_stream(tls_provider, stream);
371                let stream = stream.secure_upgrade().await?;
372                Ok(stream)
373            })
374        }
375
376        // We've got at least one pending connection here
377        debug_assert!(!self.tls_backlog.is_empty());
378        let r = ready!(Pin::new(&mut self.tls_backlog).poll_next(cx))?;
379        Poll::Ready(Some(Ok(r)))
380    }
381}
382
383impl<D: TlsDriver> futures::Stream for AcceptedStream<(Preview, Connection<D>), D> {
384    type Item = Result<(Preview, Connection<D>), ConnectionError>;
385    fn poll_next(
386        mut self: std::pin::Pin<&mut Self>,
387        cx: &mut std::task::Context<'_>,
388    ) -> Poll<Option<Self::Item>> {
389        // Fill the backlog to capacity as log as we have connections to accept.
390        while !self.tls_backlog.is_full() {
391            let Poll::Ready(r) = self.stream.poll_next_unpin(cx) else {
392                if self.tls_backlog.is_empty() {
393                    return Poll::Pending;
394                }
395                break;
396            };
397
398            let Some((mut stream, _t)) = r.transpose()? else {
399                if self.tls_backlog.is_empty() {
400                    return Poll::Ready(None);
401                }
402                break;
403            };
404
405            let tls_provider = self.tls_provider.clone();
406            let preview_configuration = self.preview_configuration.unwrap();
407            let ignore_missing_tls_close_notify = self.ignore_missing_tls_close_notify;
408            self.tls_backlog.push(async move {
409                let mut buf = smallvec::SmallVec::with_capacity(
410                    preview_configuration.max_preview_bytes.get(),
411                );
412                buf.resize(preview_configuration.max_preview_bytes.get(), 0);
413                stream.read_exact(&mut buf).await?;
414                let mut stream = RewindStream::new(stream);
415                stream.rewind(&buf);
416                let preview = Preview::new(buf);
417                let mut stream = UpgradableStream::<_, D>::new_server_preview(stream, tls_provider);
418                if ignore_missing_tls_close_notify {
419                    stream.ignore_missing_close_notify();
420                }
421
422                Ok((preview, stream))
423            })
424        }
425
426        // We've got at least one pending connection here
427        debug_assert!(!self.tls_backlog.is_empty());
428        let r = ready!(Pin::new(&mut self.tls_backlog).poll_next(cx))?;
429        Poll::Ready(Some(Ok(r)))
430    }
431}
432
/// A bounded set of in-flight futures, each resolving to an accepted
/// connection of type `C` or a [`ConnectionError`].
struct TlsAcceptBacklog<C> {
    /// Number of queued futures at which `is_full()` starts returning `true`.
    capacity: usize,
    /// The queued futures, boxed so that differently-typed futures can share
    /// one collection.
    #[allow(clippy::type_complexity)]
    futures: FuturesUnordered<
        Pin<Box<dyn Future<Output = Result<C, ConnectionError>> + Send + 'static>>,
    >,
}
440
441impl<C> TlsAcceptBacklog<C> {
442    fn new(capacity: usize) -> Self {
443        Self {
444            capacity,
445            futures: FuturesUnordered::new(),
446        }
447    }
448
449    fn is_full(&self) -> bool {
450        self.futures.len() >= self.capacity
451    }
452
453    fn is_empty(&self) -> bool {
454        self.futures.len() == 0
455    }
456
457    fn poll_next(
458        mut self: std::pin::Pin<&mut Self>,
459        cx: &mut std::task::Context<'_>,
460    ) -> Poll<Result<C, ConnectionError>> {
461        debug_assert!(!self.is_empty());
462        self.futures.poll_next_unpin(cx).map(|r| r.unwrap())
463    }
464
465    fn push(&mut self, future: impl Future<Output = Result<C, ConnectionError>> + Send + 'static) {
466        self.futures.push(Box::pin(future));
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Connector, OpensslDriver, RustlsDriver, Target, TlsParameters, TlsServerParameters,
    };
    use std::net::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Read `conn` until EOF and return everything received as a string.
    async fn drain_to_string<R: AsyncReadExt + Unpin>(conn: &mut R) -> std::io::Result<String> {
        let mut text = String::new();
        conn.read_to_string(&mut text).await?;
        Ok(text)
    }

    /// Exercise a previewing TCP/TLS acceptor with driver `D`: first with a
    /// plaintext client, then with a TLS client whose handshake bytes show up
    /// in the first preview and whose payload shows up after the upgrade.
    async fn test_acceptor_new_tcp_previewing<D: TlsDriver>() -> Result<(), ConnectionError> {
        let server_params =
            TlsServerParameters::new_with_certificate(crate::test_keys::SERVER_KEY.clone_key());
        let mut incoming = Acceptor::new_tcp_tls_previewing(
            SocketAddr::from((Ipv4Addr::LOCALHOST, 0)),
            PreviewConfiguration::default(),
            TlsServerParameterProvider::new(server_params),
        )
        .bind_explicit::<D>()
        .await?;

        // Round one: a plaintext client.
        let plain_target = incoming.local_address()?;
        tokio::task::spawn(async move {
            let mut client = Connector::new_resolved(plain_target).connect().await?;
            client.write_all(b"HELLO WORLD").await
        });

        let (plain_preview, mut plain_conn) = incoming.next().await.unwrap()?;
        assert_eq!(plain_preview.len(), 8);
        assert_eq!(plain_preview, b"HELLO WO");
        // The previewed bytes are still readable from the connection itself.
        let received = drain_to_string(&mut plain_conn).await?;
        assert_eq!(received, "HELLO WORLD");

        // Round two: a TLS client.
        let tls_addr = incoming.local_address()?;
        tokio::task::spawn(async move {
            let target = Target::new_resolved_tls(tls_addr, TlsParameters::insecure());
            let mut client = Connector::new(target)?.connect().await?;
            client.write_all(b"HELLO WORLD").await
        });

        let (raw_preview, raw_conn) = incoming.next().await.unwrap()?;
        assert_eq!(raw_preview.len(), 8);
        // The first preview must start with the expected handshake bytes.
        assert!(matches!(raw_preview.as_ref(), [0x16, 3, 1, ..]));
        let (tls_preview, mut tls_conn) = raw_conn
            .secure_upgrade_preview(PreviewConfiguration::default())
            .await?;
        assert_eq!(tls_preview.len(), 8);
        assert_eq!(tls_preview, b"HELLO WO");

        let received = drain_to_string(&mut tls_conn).await?;
        assert_eq!(received, "HELLO WORLD");

        Ok(())
    }

    #[tokio::test]
    async fn test_acceptor_new_tcp_previewing_openssl() -> Result<(), ConnectionError> {
        test_acceptor_new_tcp_previewing::<OpensslDriver>().await
    }

    #[tokio::test]
    async fn test_acceptor_new_tcp_previewing_rustls() -> Result<(), ConnectionError> {
        test_acceptor_new_tcp_previewing::<RustlsDriver>().await
    }
}