gel_stream/server/
acceptor.rs

1use crate::{
2    common::tokio_stream::TokioListenerStream, ConnectionError, LocalAddress, ResolvedTarget,
3    RewindStream, Ssl, SslError, StreamUpgrade, TlsDriver, TlsServerParameterProvider,
4    UpgradableStream,
5};
6use futures::{FutureExt, StreamExt};
7use std::{
8    future::Future,
9    pin::Pin,
10    task::{ready, Poll},
11};
12use std::{net::SocketAddr, path::Path};
13
14use super::Connection;
15
16pub struct Acceptor {
17    resolved_target: ResolvedTarget,
18    tls_provider: Option<TlsServerParameterProvider>,
19    should_upgrade: bool,
20    ignore_missing_tls_close_notify: bool,
21}
22
23impl Acceptor {
24    pub fn new(target: ResolvedTarget) -> Self {
25        Self {
26            resolved_target: target,
27            tls_provider: None,
28            should_upgrade: false,
29            ignore_missing_tls_close_notify: false,
30        }
31    }
32
33    pub fn new_tls(target: ResolvedTarget, provider: TlsServerParameterProvider) -> Self {
34        Self {
35            resolved_target: target,
36            tls_provider: Some(provider),
37            should_upgrade: true,
38            ignore_missing_tls_close_notify: false,
39        }
40    }
41
42    pub fn new_starttls(target: ResolvedTarget, provider: TlsServerParameterProvider) -> Self {
43        Self {
44            resolved_target: target,
45            tls_provider: Some(provider),
46            should_upgrade: false,
47            ignore_missing_tls_close_notify: false,
48        }
49    }
50
51    pub fn new_tcp(addr: SocketAddr) -> Self {
52        Self {
53            resolved_target: ResolvedTarget::SocketAddr(addr),
54            tls_provider: None,
55            should_upgrade: false,
56            ignore_missing_tls_close_notify: false,
57        }
58    }
59
60    pub fn new_tcp_tls(addr: SocketAddr, provider: TlsServerParameterProvider) -> Self {
61        Self {
62            resolved_target: ResolvedTarget::SocketAddr(addr),
63            tls_provider: Some(provider),
64            should_upgrade: true,
65            ignore_missing_tls_close_notify: false,
66        }
67    }
68
69    pub fn new_tcp_starttls(addr: SocketAddr, provider: TlsServerParameterProvider) -> Self {
70        Self {
71            resolved_target: ResolvedTarget::SocketAddr(addr),
72            tls_provider: Some(provider),
73            should_upgrade: false,
74            ignore_missing_tls_close_notify: false,
75        }
76    }
77
78    pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
79        #[cfg(unix)]
80        {
81            Ok(Self {
82                resolved_target: ResolvedTarget::from(
83                    std::os::unix::net::SocketAddr::from_pathname(path)?,
84                ),
85                tls_provider: None,
86                should_upgrade: false,
87                ignore_missing_tls_close_notify: false,
88            })
89        }
90        #[cfg(not(unix))]
91        {
92            Err(std::io::Error::new(
93                std::io::ErrorKind::Unsupported,
94                "Unix domain sockets are not supported on this platform",
95            ))
96        }
97    }
98
99    pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
100        #[cfg(any(target_os = "linux", target_os = "android"))]
101        {
102            use std::os::linux::net::SocketAddrExt;
103            Ok(Self {
104                resolved_target: ResolvedTarget::from(
105                    std::os::unix::net::SocketAddr::from_abstract_name(domain)?,
106                ),
107                tls_provider: None,
108                should_upgrade: false,
109                ignore_missing_tls_close_notify: false,
110            })
111        }
112        #[cfg(not(any(target_os = "linux", target_os = "android")))]
113        {
114            Err(std::io::Error::new(
115                std::io::ErrorKind::Unsupported,
116                "Unix domain sockets are not supported on this platform",
117            ))
118        }
119    }
120
121    pub async fn bind(
122        self,
123    ) -> Result<
124        impl ::futures::Stream<Item = Result<Connection, ConnectionError>> + LocalAddress,
125        ConnectionError,
126    > {
127        let stream = self.resolved_target.listen_raw().await?;
128        Ok(AcceptedStream {
129            stream,
130            should_upgrade: self.should_upgrade,
131            ignore_missing_tls_close_notify: self.ignore_missing_tls_close_notify,
132            upgrade_future: None,
133            tls_provider: self.tls_provider,
134            _phantom: None,
135        })
136    }
137
138    #[allow(private_bounds)]
139    pub async fn bind_explicit<D: TlsDriver>(
140        self,
141    ) -> Result<
142        impl ::futures::Stream<Item = Result<Connection<D>, ConnectionError>> + LocalAddress,
143        ConnectionError,
144    > {
145        let stream = self.resolved_target.listen_raw().await?;
146        Ok(AcceptedStream {
147            stream,
148            ignore_missing_tls_close_notify: self.ignore_missing_tls_close_notify,
149            should_upgrade: self.should_upgrade,
150            upgrade_future: None,
151            tls_provider: self.tls_provider,
152            _phantom: None,
153        })
154    }
155
156    pub async fn accept_one(self) -> Result<Connection, std::io::Error> {
157        let mut stream = self.resolved_target.listen().await?;
158        let (stream, _target) = stream.next().await.unwrap()?;
159        let mut stm = UpgradableStream::new_server(
160            RewindStream::new(stream),
161            None::<TlsServerParameterProvider>,
162        );
163        if self.ignore_missing_tls_close_notify {
164            stm.ignore_missing_close_notify();
165        }
166        Ok(stm)
167    }
168}
169
170struct AcceptedStream<D: TlsDriver = Ssl> {
171    stream: TokioListenerStream,
172    should_upgrade: bool,
173    ignore_missing_tls_close_notify: bool,
174    tls_provider: Option<TlsServerParameterProvider>,
175    #[allow(clippy::type_complexity)]
176    upgrade_future:
177        Option<Pin<Box<dyn Future<Output = Result<Connection<D>, SslError>> + Send + 'static>>>,
178    // Avoid using PhantomData because it fails to implement certain auto-traits
179    _phantom: Option<&'static D>,
180}
181
182impl<D: TlsDriver> LocalAddress for AcceptedStream<D> {
183    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
184        self.stream.local_address()
185    }
186}
187
188impl<D: TlsDriver> futures::Stream for AcceptedStream<D> {
189    type Item = Result<Connection<D>, ConnectionError>;
190
191    fn poll_next(
192        mut self: std::pin::Pin<&mut Self>,
193        cx: &mut std::task::Context<'_>,
194    ) -> Poll<Option<Self::Item>> {
195        if let Some(mut upgrade_future) = self.upgrade_future.take() {
196            match upgrade_future.poll_unpin(cx) {
197                Poll::Ready(Ok(conn)) => {
198                    return Poll::Ready(Some(Ok(conn)));
199                }
200                Poll::Ready(Err(e)) => {
201                    return Poll::Ready(Some(Err(e.into())));
202                }
203                Poll::Pending => {
204                    self.upgrade_future = Some(upgrade_future);
205                    return Poll::Pending;
206                }
207            }
208        }
209        let r = ready!(self.stream.poll_next_unpin(cx));
210        let Some(r) = r else {
211            return Poll::Ready(None);
212        };
213        let (stream, _target) = r?;
214        let mut stream =
215            UpgradableStream::new_server(RewindStream::new(stream), self.tls_provider.clone());
216        if self.ignore_missing_tls_close_notify {
217            stream.ignore_missing_close_notify();
218        }
219        if self.should_upgrade {
220            let mut upgrade_future = Box::pin(async move {
221                stream.secure_upgrade().await?;
222                Ok::<_, SslError>(stream)
223            });
224            match upgrade_future.poll_unpin(cx) {
225                Poll::Ready(Ok(stream)) => {
226                    return Poll::Ready(Some(Ok(stream)));
227                }
228                Poll::Ready(Err(e)) => {
229                    return Poll::Ready(Some(Err(e.into())));
230                }
231                Poll::Pending => {
232                    self.upgrade_future = Some(upgrade_future);
233                    return Poll::Pending;
234                }
235            }
236        }
237        Poll::Ready(Some(Ok(stream)))
238    }
239}