rama_unix/unix/server/
listener.rs

1use rama_core::Context;
2use rama_core::Service;
3use rama_core::graceful::ShutdownGuard;
4use rama_core::rt::Executor;
5use rama_core::telemetry::tracing::{self, Instrument};
6use std::fmt;
7use std::io;
8use std::os::fd::AsFd;
9use std::os::fd::AsRawFd;
10use std::os::fd::BorrowedFd;
11use std::os::fd::RawFd;
12use std::os::unix::net::UnixListener as StdUnixListener;
13use std::path::Path;
14use std::path::PathBuf;
15use std::pin::pin;
16use std::sync::Arc;
17use tokio::net::UnixListener as TokioUnixListener;
18use tokio::net::unix::SocketAddr;
19
20#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
21use rama_net::socket::SocketOptions;
22
23use crate::UnixSocketAddress;
24use crate::UnixSocketInfo;
25use crate::UnixStream;
26
/// Builder for `UnixListener`.
///
/// Holds the state that the resulting listener will own once it is bound
/// through one of the `bind_*` methods.
pub struct UnixListenerBuilder<S> {
    /// State moved into the listener when it gets bound.
    state: S,
}
31
32impl<S> fmt::Debug for UnixListenerBuilder<S>
33where
34    S: fmt::Debug,
35{
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        f.debug_struct("UnixListenerBuilder")
38            .field("state", &self.state)
39            .finish()
40    }
41}
42
43impl UnixListenerBuilder<()> {
44    /// Create a new `UnixListenerBuilder` without a state.
45    #[must_use]
46    pub fn new() -> Self {
47        Self { state: () }
48    }
49}
50
51impl Default for UnixListenerBuilder<()> {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl<S: Clone> Clone for UnixListenerBuilder<S> {
58    fn clone(&self) -> Self {
59        Self {
60            state: self.state.clone(),
61        }
62    }
63}
64
65impl<S> UnixListenerBuilder<S>
66where
67    S: Clone + Send + Sync + 'static,
68{
69    /// Create a new `TcpListenerBuilder` with the given state.
70    pub fn with_state(state: S) -> Self {
71        Self { state }
72    }
73}
74
75impl<S> UnixListenerBuilder<S>
76where
77    S: Clone + Send + Sync + 'static,
78{
79    /// Creates a new [`UnixListener`], which will be bound to the specified path.
80    ///
81    /// The returned listener is ready for accepting connections.
82    pub async fn bind_path(self, path: impl AsRef<Path>) -> Result<UnixListener<S>, io::Error> {
83        let path = path.as_ref();
84
85        if tokio::fs::try_exists(path).await.unwrap_or_default() {
86            tracing::trace!(file.path = ?path, "try delete existing UNIX socket path");
87            // some errors might lead to false positives (e.g. no permissions),
88            // this is ok as this is a best-effort cleanup to anyway only be of use
89            // if we have permission to do so
90            tokio::fs::remove_file(path).await?;
91        }
92
93        let inner = TokioUnixListener::bind(path)?;
94        let cleanup = Some(UnixSocketCleanup {
95            path: path.to_owned(),
96        });
97
98        Ok(UnixListener {
99            inner,
100            state: self.state,
101            cleanup,
102        })
103    }
104
105    /// Creates a new [`UnixListener`], which will be bound to the specified socket.
106    ///
107    /// The returned listener is ready for accepting connections.
108    pub fn bind_socket(
109        self,
110        socket: rama_net::socket::core::Socket,
111    ) -> Result<UnixListener<S>, io::Error> {
112        let std_listener: StdUnixListener = socket.into();
113        std_listener.set_nonblocking(true)?;
114        let inner = TokioUnixListener::from_std(std_listener)?;
115        Ok(UnixListener {
116            inner,
117            state: self.state,
118            cleanup: None,
119        })
120    }
121
122    /// Creates a new TcpListener, which will be bound to the specified interface.
123    ///
124    /// The returned listener is ready for accepting connections.
125    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
126    pub async fn bind_socket_opts(
127        self,
128        opts: SocketOptions,
129    ) -> Result<UnixListener<S>, rama_core::error::BoxError> {
130        let socket = tokio::task::spawn_blocking(move || opts.try_build_socket()).await??;
131        Ok(self.bind_socket(socket)?)
132    }
133}
134
/// A Unix (domain) socket server, listening for incoming connections once served
/// using one of the `serve` methods such as [`UnixListener::serve`].
///
/// Note that the underlying socket (file) is only cleaned up
/// when this listener is dropped if the listener
/// was created using the `bind_path` constructor. Otherwise
/// it is assumed that the creator of this listener is in charge
/// of that cleanup.
pub struct UnixListener<S> {
    // The tokio listener used to accept incoming connections.
    inner: TokioUnixListener,
    // State used to build the `Context` handed to served connections.
    state: S,
    // Removes the socket file when dropped; `Some` only for `bind_path` listeners.
    // Kept after `inner`: struct fields drop in declaration order, so the
    // listening socket is closed before its file is removed.
    cleanup: Option<UnixSocketCleanup>,
}
148
149impl<S> fmt::Debug for UnixListener<S>
150where
151    S: fmt::Debug,
152{
153    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
154        f.debug_struct("UnixListener")
155            .field("inner", &self.inner)
156            .field("state", &self.state)
157            .field("cleanup", &self.cleanup)
158            .finish()
159    }
160}
161
162impl UnixListener<()> {
163    #[inline]
164    /// Create a new [`UnixListenerBuilder`] without a state,
165    /// which can be used to configure a [`UnixListener`].
166    #[must_use]
167    pub fn build() -> UnixListenerBuilder<()> {
168        UnixListenerBuilder::new()
169    }
170
171    #[inline]
172    /// Create a new [`UnixListenerBuilder`] with the given state,
173    /// which can be used to configure a [`UnixListener`].
174    pub fn build_with_state<S>(state: S) -> UnixListenerBuilder<S>
175    where
176        S: Clone + Send + Sync + 'static,
177    {
178        UnixListenerBuilder::with_state(state)
179    }
180
181    #[inline]
182    /// Creates a new [`UnixListener`], which will be bound to the specified path.
183    ///
184    /// The returned listener is ready for accepting connections.
185    pub async fn bind_path(path: impl AsRef<Path>) -> Result<Self, io::Error> {
186        UnixListenerBuilder::default().bind_path(path).await
187    }
188
189    #[inline]
190    /// Creates a new [`UnixListener`], which will be bound to the specified socket.
191    ///
192    /// The returned listener is ready for accepting connections.
193    pub fn bind_socket(socket: rama_net::socket::core::Socket) -> Result<Self, io::Error> {
194        UnixListenerBuilder::default().bind_socket(socket)
195    }
196
197    #[inline]
198    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
199    /// Creates a new TcpListener, which will be bound to the specified (interface) device name.
200    ///
201    /// The returned listener is ready for accepting connections.
202    pub async fn bind_socket_opts(opts: SocketOptions) -> Result<Self, rama_core::error::BoxError> {
203        UnixListenerBuilder::default().bind_socket_opts(opts).await
204    }
205}
206
207impl<S> UnixListener<S> {
208    /// Returns the local address that this listener is bound to.
209    ///
210    /// This can be useful, for example, when binding to port 0 to figure out
211    /// which port was actually bound.
212    pub fn local_addr(&self) -> io::Result<SocketAddr> {
213        self.inner.local_addr()
214    }
215
216    /// Gets a reference to the listener's state.
217    pub fn state(&self) -> &S {
218        &self.state
219    }
220
221    /// Gets an exclusive reference to the listener's state.
222    pub fn state_mut(&mut self) -> &mut S {
223        &mut self.state
224    }
225}
226
227impl From<TokioUnixListener> for UnixListener<()> {
228    fn from(value: TokioUnixListener) -> Self {
229        Self {
230            inner: value,
231            state: (),
232            cleanup: None,
233        }
234    }
235}
236
237impl TryFrom<rama_net::socket::core::Socket> for UnixListener<()> {
238    type Error = io::Error;
239
240    #[inline]
241    fn try_from(socket: rama_net::socket::core::Socket) -> Result<Self, Self::Error> {
242        Self::bind_socket(socket)
243    }
244}
245
246impl TryFrom<StdUnixListener> for UnixListener<()> {
247    type Error = io::Error;
248
249    fn try_from(listener: StdUnixListener) -> Result<Self, Self::Error> {
250        listener.set_nonblocking(true)?;
251        let inner = TokioUnixListener::from_std(listener)?;
252        Ok(Self {
253            inner,
254            state: (),
255            cleanup: None,
256        })
257    }
258}
259
260impl<S> AsRawFd for UnixListener<S> {
261    #[inline]
262    fn as_raw_fd(&self) -> RawFd {
263        self.inner.as_raw_fd()
264    }
265}
266
267impl<S> AsFd for UnixListener<S> {
268    #[inline]
269    fn as_fd(&self) -> BorrowedFd<'_> {
270        self.inner.as_fd()
271    }
272}
273
274impl UnixListener<()> {
275    /// Define the TcpListener's state after it was created,
276    /// useful in case it wasn't built using the builder.
277    pub fn with_state<S>(self, state: S) -> UnixListener<S> {
278        UnixListener {
279            inner: self.inner,
280            state,
281            cleanup: self.cleanup,
282        }
283    }
284}
285
286impl<State> UnixListener<State>
287where
288    State: Clone + Send + Sync + 'static,
289{
290    /// Accept a single connection from this listener,
291    /// what you can do with whatever you want.
292    #[inline]
293    pub async fn accept(&self) -> io::Result<(UnixStream, UnixSocketAddress)> {
294        let (stream, addr) = self.inner.accept().await?;
295        Ok((stream, addr.into()))
296    }
297
298    /// Serve connections from this listener with the given service.
299    ///
300    /// This method will block the current listener for each incoming connection,
301    /// the underlying service can choose to spawn a task to handle the accepted stream.
302    pub async fn serve<S>(self, service: S)
303    where
304        S: Service<State, UnixStream>,
305    {
306        let ctx = Context::new(self.state, Executor::new());
307        let service = Arc::new(service);
308
309        loop {
310            let (socket, peer_addr) = match self.inner.accept().await {
311                Ok(stream) => stream,
312                Err(err) => {
313                    handle_accept_err(err).await;
314                    continue;
315                }
316            };
317
318            let service = service.clone();
319            let mut ctx = ctx.clone();
320
321            let peer_addr: UnixSocketAddress = peer_addr.into();
322            let local_addr: Option<UnixSocketAddress> = socket.local_addr().ok().map(Into::into);
323
324            let serve_span = tracing::trace_root_span!(
325                "unix::serve",
326                otel.kind = "server",
327                network.local.address = ?local_addr,
328                network.peer.address = ?peer_addr,
329                network.protocol.name = "uds",
330            );
331
332            tokio::spawn(
333                async move {
334                    ctx.insert(UnixSocketInfo::new(socket.local_addr().ok(), peer_addr));
335                    let _ = service.serve(ctx, socket).await;
336                }
337                .instrument(serve_span),
338            );
339        }
340    }
341
342    /// Serve gracefully connections from this listener with the given service.
343    ///
344    /// This method does the same as [`Self::serve`] but it
345    /// will respect the given [`rama_core::graceful::ShutdownGuard`], and also pass
346    /// it to the service.
347    pub async fn serve_graceful<S>(self, guard: ShutdownGuard, service: S)
348    where
349        S: Service<State, UnixStream>,
350    {
351        let ctx: Context<State> = Context::new(self.state, Executor::graceful(guard.clone()));
352        let service = Arc::new(service);
353        let mut cancelled_fut = pin!(guard.cancelled());
354
355        loop {
356            tokio::select! {
357                _ = cancelled_fut.as_mut() => {
358                    tracing::trace!("signal received: initiate graceful shutdown");
359                    break;
360                }
361                result = self.inner.accept() => {
362                    match result {
363                        Ok((socket, peer_addr)) => {
364                            let service = service.clone();
365                            let mut ctx = ctx.clone();
366
367                            let peer_addr: UnixSocketAddress = peer_addr.into();
368                            let local_addr: Option<UnixSocketAddress> = socket.local_addr().ok().map(Into::into);
369
370                            let serve_span = tracing::trace_root_span!(
371                                "unix::serve_graceful",
372                                otel.kind = "server",
373                                network.local.address = ?local_addr,
374                                network.peer.address = ?peer_addr,
375                                network.protocol.name = "uds",
376                            );
377
378                            guard.spawn_task(async move {
379                                ctx.insert(UnixSocketInfo::new(local_addr, peer_addr));
380
381                                let _ = service.serve(ctx, socket).await;
382                            }.instrument(serve_span));
383                        }
384                        Err(err) => {
385                            handle_accept_err(err).await;
386                        }
387                    }
388                }
389            }
390        }
391    }
392}
393
394async fn handle_accept_err(err: io::Error) {
395    if rama_net::conn::is_connection_error(&err) {
396        tracing::trace!("unix accept error: connect error: {err:?}");
397    } else {
398        tracing::error!("unix accept error: {err:?}");
399    }
400}
401
/// Removes the UNIX socket file at `path` when dropped.
#[derive(Debug)]
struct UnixSocketCleanup {
    /// Filesystem path of the socket file to remove.
    path: PathBuf,
}
406
407impl Drop for UnixSocketCleanup {
408    fn drop(&mut self) {
409        if let Err(err) = std::fs::remove_file(&self.path) {
410            tracing::debug!(file.path = ?self.path, "failed to remove unix listener's file socket {err:?}");
411        }
412    }
413}