1use rama_core::Context;
2use rama_core::Service;
3use rama_core::graceful::ShutdownGuard;
4use rama_core::rt::Executor;
5use rama_core::telemetry::tracing::{self, Instrument};
6use std::fmt;
7use std::io;
8use std::os::fd::AsFd;
9use std::os::fd::AsRawFd;
10use std::os::fd::BorrowedFd;
11use std::os::fd::RawFd;
12use std::os::unix::net::UnixListener as StdUnixListener;
13use std::path::Path;
14use std::path::PathBuf;
15use std::pin::pin;
16use std::sync::Arc;
17use tokio::net::UnixListener as TokioUnixListener;
18use tokio::net::unix::SocketAddr;
19
20#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
21use rama_net::socket::SocketOptions;
22
23use crate::UnixSocketAddress;
24use crate::UnixSocketInfo;
25use crate::UnixStream;
26
/// Builder used to configure and bind a [`UnixListener`].
///
/// The builder carries the state `S` that is moved into the listener once it
/// is bound. [`UnixListenerBuilder::new`] starts from the unit state `()`.
#[derive(Debug, Clone)]
pub struct UnixListenerBuilder<S> {
    state: S,
}

impl UnixListenerBuilder<()> {
    /// Create a new [`UnixListenerBuilder`] without state.
    #[must_use]
    pub fn new() -> Self {
        Self { state: () }
    }
}

impl Default for UnixListenerBuilder<()> {
    fn default() -> Self {
        Self::new()
    }
}
64
65impl<S> UnixListenerBuilder<S>
66where
67 S: Clone + Send + Sync + 'static,
68{
69 pub fn with_state(state: S) -> Self {
71 Self { state }
72 }
73}
74
75impl<S> UnixListenerBuilder<S>
76where
77 S: Clone + Send + Sync + 'static,
78{
79 pub async fn bind_path(self, path: impl AsRef<Path>) -> Result<UnixListener<S>, io::Error> {
83 let path = path.as_ref();
84
85 if tokio::fs::try_exists(path).await.unwrap_or_default() {
86 tracing::trace!(file.path = ?path, "try delete existing UNIX socket path");
87 tokio::fs::remove_file(path).await?;
91 }
92
93 let inner = TokioUnixListener::bind(path)?;
94 let cleanup = Some(UnixSocketCleanup {
95 path: path.to_owned(),
96 });
97
98 Ok(UnixListener {
99 inner,
100 state: self.state,
101 cleanup,
102 })
103 }
104
105 pub fn bind_socket(
109 self,
110 socket: rama_net::socket::core::Socket,
111 ) -> Result<UnixListener<S>, io::Error> {
112 let std_listener: StdUnixListener = socket.into();
113 std_listener.set_nonblocking(true)?;
114 let inner = TokioUnixListener::from_std(std_listener)?;
115 Ok(UnixListener {
116 inner,
117 state: self.state,
118 cleanup: None,
119 })
120 }
121
122 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
126 pub async fn bind_socket_opts(
127 self,
128 opts: SocketOptions,
129 ) -> Result<UnixListener<S>, rama_core::error::BoxError> {
130 let socket = tokio::task::spawn_blocking(move || opts.try_build_socket()).await??;
131 Ok(self.bind_socket(socket)?)
132 }
133}
134
135pub struct UnixListener<S> {
144 inner: TokioUnixListener,
145 state: S,
146 cleanup: Option<UnixSocketCleanup>,
147}
148
149impl<S> fmt::Debug for UnixListener<S>
150where
151 S: fmt::Debug,
152{
153 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
154 f.debug_struct("UnixListener")
155 .field("inner", &self.inner)
156 .field("state", &self.state)
157 .field("cleanup", &self.cleanup)
158 .finish()
159 }
160}
161
162impl UnixListener<()> {
163 #[inline]
164 #[must_use]
167 pub fn build() -> UnixListenerBuilder<()> {
168 UnixListenerBuilder::new()
169 }
170
171 #[inline]
172 pub fn build_with_state<S>(state: S) -> UnixListenerBuilder<S>
175 where
176 S: Clone + Send + Sync + 'static,
177 {
178 UnixListenerBuilder::with_state(state)
179 }
180
181 #[inline]
182 pub async fn bind_path(path: impl AsRef<Path>) -> Result<Self, io::Error> {
186 UnixListenerBuilder::default().bind_path(path).await
187 }
188
189 #[inline]
190 pub fn bind_socket(socket: rama_net::socket::core::Socket) -> Result<Self, io::Error> {
194 UnixListenerBuilder::default().bind_socket(socket)
195 }
196
197 #[inline]
198 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
199 pub async fn bind_socket_opts(opts: SocketOptions) -> Result<Self, rama_core::error::BoxError> {
203 UnixListenerBuilder::default().bind_socket_opts(opts).await
204 }
205}
206
207impl<S> UnixListener<S> {
208 pub fn local_addr(&self) -> io::Result<SocketAddr> {
213 self.inner.local_addr()
214 }
215
216 pub fn state(&self) -> &S {
218 &self.state
219 }
220
221 pub fn state_mut(&mut self) -> &mut S {
223 &mut self.state
224 }
225}
226
227impl From<TokioUnixListener> for UnixListener<()> {
228 fn from(value: TokioUnixListener) -> Self {
229 Self {
230 inner: value,
231 state: (),
232 cleanup: None,
233 }
234 }
235}
236
237impl TryFrom<rama_net::socket::core::Socket> for UnixListener<()> {
238 type Error = io::Error;
239
240 #[inline]
241 fn try_from(socket: rama_net::socket::core::Socket) -> Result<Self, Self::Error> {
242 Self::bind_socket(socket)
243 }
244}
245
246impl TryFrom<StdUnixListener> for UnixListener<()> {
247 type Error = io::Error;
248
249 fn try_from(listener: StdUnixListener) -> Result<Self, Self::Error> {
250 listener.set_nonblocking(true)?;
251 let inner = TokioUnixListener::from_std(listener)?;
252 Ok(Self {
253 inner,
254 state: (),
255 cleanup: None,
256 })
257 }
258}
259
260impl<S> AsRawFd for UnixListener<S> {
261 #[inline]
262 fn as_raw_fd(&self) -> RawFd {
263 self.inner.as_raw_fd()
264 }
265}
266
267impl<S> AsFd for UnixListener<S> {
268 #[inline]
269 fn as_fd(&self) -> BorrowedFd<'_> {
270 self.inner.as_fd()
271 }
272}
273
274impl UnixListener<()> {
275 pub fn with_state<S>(self, state: S) -> UnixListener<S> {
278 UnixListener {
279 inner: self.inner,
280 state,
281 cleanup: self.cleanup,
282 }
283 }
284}
285
286impl<State> UnixListener<State>
287where
288 State: Clone + Send + Sync + 'static,
289{
290 #[inline]
293 pub async fn accept(&self) -> io::Result<(UnixStream, UnixSocketAddress)> {
294 let (stream, addr) = self.inner.accept().await?;
295 Ok((stream, addr.into()))
296 }
297
298 pub async fn serve<S>(self, service: S)
303 where
304 S: Service<State, UnixStream>,
305 {
306 let ctx = Context::new(self.state, Executor::new());
307 let service = Arc::new(service);
308
309 loop {
310 let (socket, peer_addr) = match self.inner.accept().await {
311 Ok(stream) => stream,
312 Err(err) => {
313 handle_accept_err(err).await;
314 continue;
315 }
316 };
317
318 let service = service.clone();
319 let mut ctx = ctx.clone();
320
321 let peer_addr: UnixSocketAddress = peer_addr.into();
322 let local_addr: Option<UnixSocketAddress> = socket.local_addr().ok().map(Into::into);
323
324 let serve_span = tracing::trace_root_span!(
325 "unix::serve",
326 otel.kind = "server",
327 network.local.address = ?local_addr,
328 network.peer.address = ?peer_addr,
329 network.protocol.name = "uds",
330 );
331
332 tokio::spawn(
333 async move {
334 ctx.insert(UnixSocketInfo::new(socket.local_addr().ok(), peer_addr));
335 let _ = service.serve(ctx, socket).await;
336 }
337 .instrument(serve_span),
338 );
339 }
340 }
341
342 pub async fn serve_graceful<S>(self, guard: ShutdownGuard, service: S)
348 where
349 S: Service<State, UnixStream>,
350 {
351 let ctx: Context<State> = Context::new(self.state, Executor::graceful(guard.clone()));
352 let service = Arc::new(service);
353 let mut cancelled_fut = pin!(guard.cancelled());
354
355 loop {
356 tokio::select! {
357 _ = cancelled_fut.as_mut() => {
358 tracing::trace!("signal received: initiate graceful shutdown");
359 break;
360 }
361 result = self.inner.accept() => {
362 match result {
363 Ok((socket, peer_addr)) => {
364 let service = service.clone();
365 let mut ctx = ctx.clone();
366
367 let peer_addr: UnixSocketAddress = peer_addr.into();
368 let local_addr: Option<UnixSocketAddress> = socket.local_addr().ok().map(Into::into);
369
370 let serve_span = tracing::trace_root_span!(
371 "unix::serve_graceful",
372 otel.kind = "server",
373 network.local.address = ?local_addr,
374 network.peer.address = ?peer_addr,
375 network.protocol.name = "uds",
376 );
377
378 guard.spawn_task(async move {
379 ctx.insert(UnixSocketInfo::new(local_addr, peer_addr));
380
381 let _ = service.serve(ctx, socket).await;
382 }.instrument(serve_span));
383 }
384 Err(err) => {
385 handle_accept_err(err).await;
386 }
387 }
388 }
389 }
390 }
391 }
392}
393
394async fn handle_accept_err(err: io::Error) {
395 if rama_net::conn::is_connection_error(&err) {
396 tracing::trace!("unix accept error: connect error: {err:?}");
397 } else {
398 tracing::error!("unix accept error: {err:?}");
399 }
400}
401
402#[derive(Debug)]
403struct UnixSocketCleanup {
404 path: PathBuf,
405}
406
407impl Drop for UnixSocketCleanup {
408 fn drop(&mut self) {
409 if let Err(err) = std::fs::remove_file(&self.path) {
410 tracing::debug!(file.path = ?self.path, "failed to remove unix listener's file socket {err:?}");
411 }
412 }
413}