axum_listener/
dual.rs

1use axum::serve::Listener;
2use std::net::ToSocketAddrs;
3#[cfg(unix)]
4use tokio::net::UnixListener;
5
6/// A unified listener that can bind to either TCP or Unix Domain Socket addresses.
7///
8/// This enum allows you to create a single listener type that can handle both TCP and UDS
9/// connections transparently. The specific variant is determined at runtime based on the
10/// address format provided to [`DualListener::bind`].
11///
12/// # Examples
13///
14/// ```rust,no_run
15/// # tokio_test::block_on(async {
16/// use axum_listener::listener::DualListener;
17///
18/// // Bind to TCP
19/// let tcp_listener = DualListener::bind("localhost:8080").await.unwrap();
20///
21/// // Bind to Unix Domain Socket (on Unix systems)
22/// # #[cfg(unix)] {
23/// let uds_listener = DualListener::bind("unix:/tmp/app.sock").await.unwrap();
24/// # }
25/// # });
26/// ```
27///
28/// # Platform Support
29///
30/// - `Tcp` variant is available on all platforms
31/// - `Uds` variant is only available on Unix-like systems
32#[derive(Debug)]
33pub enum DualListener {
34    /// A TCP listener for network connections
35    Tcp(tokio::net::TcpListener),
36    /// A Unix Domain Socket listener for local inter-process communication
37    #[cfg(unix)]
38    Uds(tokio::net::UnixListener),
39}
40
41/// An address that can represent either a TCP socket address or a Unix Domain Socket address.
42///
43/// This enum is used to represent the local and remote addresses for connections
44/// accepted by [`DualListener`]. It automatically implements cleanup for UDS
45/// socket files when the `remove-on-drop` feature is enabled.
46///
47/// # Examples
48///
49/// ```rust
50/// use axum_listener::listener::DualAddr;
51/// use std::str::FromStr;
52///
53/// // Parse a TCP address
54/// let tcp_addr = DualAddr::from_str("127.0.0.1:8080").unwrap();
55///
56/// // Parse a UDS address (on Unix systems)
57/// # #[cfg(unix)] {
58/// let uds_addr = DualAddr::from_str("unix:/tmp/app.sock").unwrap();
59/// # }
60/// ```
61#[derive(Debug, Clone)]
62#[allow(dead_code)]
63pub enum DualAddr {
64    /// A TCP socket address (IPv4 or IPv6)
65    Tcp(core::net::SocketAddr),
66    /// A Unix Domain Socket address
67    #[cfg(unix)]
68    Uds(tokio::net::unix::SocketAddr),
69}
70
71impl From<core::net::SocketAddr> for DualAddr {
72    fn from(addr: core::net::SocketAddr) -> Self {
73        DualAddr::Tcp(addr)
74    }
75}
76
77#[cfg(unix)]
78impl From<tokio::net::unix::SocketAddr> for DualAddr {
79    fn from(addr: tokio::net::unix::SocketAddr) -> Self {
80        DualAddr::Uds(addr)
81    }
82}
83
84impl core::str::FromStr for DualAddr {
85    type Err = std::io::Error;
86
87    fn from_str(s: &str) -> Result<Self, Self::Err> {
88        let unix_like = s.starts_with("/") || s.starts_with("unix:");
89        let has_uds = cfg!(unix);
90        let tcp_like = s.to_socket_addrs().is_ok();
91
92        if unix_like && has_uds && !tcp_like {
93            #[cfg(unix)]
94            {
95                let path = s.trim_start_matches("unix:");
96                let addr = From::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
97                Ok(DualAddr::Uds(addr))
98            }
99            #[cfg(not(unix))]
100            {
101                Err(std::io::Error::new(
102                    std::io::ErrorKind::Other,
103                    "Unix domain sockets are not supported on this platform",
104                ))
105            }
106        } else if tcp_like {
107            let addr = s.to_socket_addrs()?.next().ok_or_else(|| {
108                std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid TCP address")
109            })?;
110            Ok(DualAddr::Tcp(addr))
111        } else if unix_like && !has_uds {
112            Err(std::io::Error::other(
113                "Unix domain sockets are not supported on this platform",
114            ))
115        } else {
116            Err(std::io::Error::new(
117                std::io::ErrorKind::InvalidInput,
118                "Invalid address format",
119            ))
120        }
121    }
122}
123
124/// A trait for types that can be converted to a [`DualAddr`].
125///
126/// This trait enables convenient address binding by allowing various types
127/// to be converted to the unified [`DualAddr`] type. It's implemented for
128/// common address types including strings, socket addresses, and paths.
129///
130/// # Examples
131///
132/// ```rust
133/// use axum_listener::listener::{ToDualAddr, DualAddr};
134/// use std::net::SocketAddr;
135///
136/// // String addresses
137/// let addr1 = "127.0.0.1:8080".to_dual_addr().unwrap();
138/// let addr2 = "unix:/tmp/app.sock".to_dual_addr().unwrap();
139///
140/// // Socket address
141/// let socket_addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
142/// let addr3 = socket_addr.to_dual_addr().unwrap();
143/// ```
144pub trait ToDualAddr {
145    /// Convert this type to a [`DualAddr`].
146    ///
147    /// # Errors
148    ///
149    /// Returns an [`std::io::Error`] if the address format is invalid or
150    /// if Unix Domain Sockets are not supported on the current platform.
151    fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error>;
152}
153
154impl ToDualAddr for &str {
155    fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
156        self.parse()
157    }
158}
159
160impl ToDualAddr for String {
161    fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
162        self.as_str().to_dual_addr()
163    }
164}
165
166impl ToDualAddr for core::net::SocketAddr {
167    fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
168        Ok(DualAddr::Tcp(*self))
169    }
170}
171
172#[cfg(unix)]
173impl ToDualAddr for tokio::net::unix::SocketAddr {
174    fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
175        Ok(DualAddr::Uds(self.clone()))
176    }
177}
178
179impl ToDualAddr for DualAddr {
180    fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
181        Ok(self.clone())
182    }
183}
184
185impl ToDualAddr for &DualAddr {
186    fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
187        Ok((*self).clone())
188    }
189}
190
191#[cfg(unix)]
192impl ToDualAddr for &std::path::Path {
193    fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
194        Ok(DualAddr::Uds(From::from(
195            std::os::unix::net::SocketAddr::from_pathname(self)?,
196        )))
197    }
198}
199
200#[cfg(unix)]
201impl ToDualAddr for std::path::PathBuf {
202    fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
203        self.as_path().to_dual_addr()
204    }
205}
206
207impl DualListener {
208    /// Creates a new [`DualListener`] bound to the specified address.
209    ///
210    /// This method accepts any type that implements [`ToDualAddr`], allowing
211    /// for flexible address specification. The listener type (TCP or UDS) is
212    /// automatically determined based on the address format.
213    ///
214    /// # Arguments
215    ///
216    /// * `address` - An address that can be converted to [`DualAddr`]
217    ///
218    /// # Returns
219    ///
220    /// Returns a [`DualListener`] bound to the specified address, or an error
221    /// if binding fails.
222    ///
223    /// # Examples
224    ///
225    /// ```rust,no_run
226    /// # tokio_test::block_on(async {
227    /// use axum_listener::listener::DualListener;
228    ///
229    /// // Bind to TCP address
230    /// let listener = DualListener::bind("localhost:8080").await.unwrap();
231    ///
232    /// // Bind to UDS address (Unix only)
233    /// # #[cfg(unix)] {
234    /// let listener = DualListener::bind("unix:/tmp/app.sock").await.unwrap();
235    /// # }
236    /// # });
237    /// ```
238    ///
239    /// # Errors
240    ///
241    /// This method can fail if:
242    /// - The address format is invalid
243    /// - The address is already in use
244    /// - Permission is denied for the requested address
245    /// - Unix Domain Sockets are not supported on the current platform
246    pub async fn bind<A: ToDualAddr>(address: A) -> Result<Self, std::io::Error> {
247        let address = address.to_dual_addr()?;
248        match address {
249            DualAddr::Tcp(addr) => {
250                let listener = tokio::net::TcpListener::bind(addr).await?;
251                Ok(DualListener::Tcp(listener))
252            }
253            #[cfg(unix)]
254            DualAddr::Uds(ref addr) => {
255                let path = addr.as_pathname().ok_or_else(|| {
256                    std::io::Error::new(
257                        std::io::ErrorKind::InvalidInput,
258                        "UDS address does not have a valid pathname",
259                    )
260                })?;
261                let listener = UnixListener::bind(path)?;
262                Ok(DualListener::Uds(listener))
263            }
264            #[cfg(not(unix))]
265            DualAddr::Uds(_) => Err(std::io::Error::new(
266                std::io::ErrorKind::Other,
267                "Unix domain sockets are not supported on this platform",
268            )),
269        }
270    }
271
272    /// Accepts a new incoming connection from this listener.
273    ///
274    /// This method will wait for a connection to be established and return
275    /// a stream and address representing the connection.
276    ///
277    /// # Returns
278    ///
279    /// Returns a tuple containing:
280    /// - [`DualStream`]: The stream for communicating with the client
281    /// - [`DualAddr`]: The address of the connected client
282    ///
283    /// # Examples
284    ///
285    /// ```rust,no_run
286    /// # tokio_test::block_on(async {
287    /// use axum_listener::listener::DualListener;
288    ///
289    /// let listener = DualListener::bind("localhost:8080").await.unwrap();
290    ///
291    /// // Accept a connection
292    /// let (stream, addr) = listener.accept().await.unwrap();
293    /// println!("Accepted connection from: {:?}", addr);
294    /// # });
295    /// ```
296    ///
297    /// # Errors
298    ///
299    /// This method can fail if there's an I/O error while accepting the connection.
300    pub async fn accept(&self) -> Result<(DualStream, DualAddr), std::io::Error> {
301        match self {
302            DualListener::Tcp(listener) => {
303                let (stream, addr) = listener.accept().await?;
304                Ok((DualStream::Tcp(stream), DualAddr::Tcp(addr)))
305            }
306            #[cfg(unix)]
307            DualListener::Uds(listener) => {
308                let (stream, addr) = listener.accept().await?;
309                Ok((DualStream::Uds(stream), DualAddr::Uds(addr)))
310            }
311        }
312    }
313
314    pub(crate) fn _accept_unpin(
315        &self,
316    ) -> impl core::future::Future<Output = Result<(DualStream, DualAddr), std::io::Error>>
317    + Unpin
318    + use<'_> {
319        Box::pin(async move {
320            match self {
321                DualListener::Tcp(listener) => {
322                    let (stream, addr) = listener.accept().await?;
323                    Ok((DualStream::Tcp(stream), DualAddr::Tcp(addr)))
324                }
325                #[cfg(unix)]
326                DualListener::Uds(listener) => {
327                    let (stream, addr) = listener.accept().await?;
328                    Ok((DualStream::Uds(stream), DualAddr::Uds(addr)))
329                }
330            }
331        })
332    }
333    pub(crate) async fn _accept_axum(&mut self) -> (DualStream, DualAddr) {
334        match self {
335            DualListener::Tcp(listener) => {
336                let (stream, addr) = Listener::accept(listener).await;
337                (DualStream::Tcp(stream), DualAddr::Tcp(addr))
338            }
339            #[cfg(unix)]
340            DualListener::Uds(listener) => {
341                let (stream, addr) = Listener::accept(listener).await;
342                (DualStream::Uds(stream), DualAddr::Uds(addr))
343            }
344        }
345    }
346
347    pub(crate) fn _accept_axum_unpin(
348        &mut self,
349    ) -> impl core::future::Future<Output = (DualStream, DualAddr)> + Unpin + use<'_> {
350        Box::pin(async move {
351            match self {
352                DualListener::Tcp(listener) => {
353                    let (stream, addr) = Listener::accept(listener).await;
354                    (DualStream::Tcp(stream), DualAddr::Tcp(addr))
355                }
356                #[cfg(unix)]
357                DualListener::Uds(listener) => {
358                    let (stream, addr) = Listener::accept(listener).await;
359                    (DualStream::Uds(stream), DualAddr::Uds(addr))
360                }
361            }
362        })
363    }
364}
365
366/// A stream that can be either a TCP stream or a Unix Domain Socket stream.
367///
368/// This enum provides a unified interface for both TCP and UDS connections,
369/// implementing the necessary async I/O traits to work seamlessly with Axum
370/// and other async frameworks.
371///
372/// # Examples
373///
374/// ```rust,no_run
375/// # tokio_test::block_on(async {
376/// use axum_listener::listener::DualListener;
377///
378/// let listener = DualListener::bind("localhost:8080").await.unwrap();
379/// let (stream, _addr) = listener.accept().await.unwrap();
380///
381/// // The stream can be used with Axum or any other async framework
382/// // that works with tokio's AsyncRead + AsyncWrite traits
383/// println!("Accepted connection from: {:?}", _addr);
384/// # });
385/// ```
386pub enum DualStream {
387    /// A TCP stream for network connections
388    Tcp(tokio::net::TcpStream),
389    /// A Unix Domain Socket stream for local inter-process communication
390    #[cfg(unix)]
391    Uds(tokio::net::UnixStream),
392}
393
394impl From<tokio::net::TcpStream> for DualStream {
395    fn from(stream: tokio::net::TcpStream) -> Self {
396        DualStream::Tcp(stream)
397    }
398}
399
400#[cfg(unix)]
401impl From<tokio::net::UnixStream> for DualStream {
402    fn from(stream: tokio::net::UnixStream) -> Self {
403        DualStream::Uds(stream)
404    }
405}
406
407impl tokio::io::AsyncRead for DualStream {
408    fn poll_read(
409        self: std::pin::Pin<&mut Self>,
410        cx: &mut std::task::Context<'_>,
411        buf: &mut tokio::io::ReadBuf<'_>,
412    ) -> std::task::Poll<std::io::Result<()>> {
413        match self.get_mut() {
414            DualStream::Tcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
415            #[cfg(unix)]
416            DualStream::Uds(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
417        }
418    }
419}
420
421impl tokio::io::AsyncWrite for DualStream {
422    fn poll_write(
423        self: std::pin::Pin<&mut Self>,
424        cx: &mut std::task::Context<'_>,
425        buf: &[u8],
426    ) -> std::task::Poll<std::io::Result<usize>> {
427        match self.get_mut() {
428            DualStream::Tcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
429            #[cfg(unix)]
430            DualStream::Uds(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
431        }
432    }
433
434    fn poll_flush(
435        self: std::pin::Pin<&mut Self>,
436        cx: &mut std::task::Context<'_>,
437    ) -> std::task::Poll<std::io::Result<()>> {
438        match self.get_mut() {
439            DualStream::Tcp(stream) => std::pin::Pin::new(stream).poll_flush(cx),
440            #[cfg(unix)]
441            DualStream::Uds(stream) => std::pin::Pin::new(stream).poll_flush(cx),
442        }
443    }
444
445    fn poll_shutdown(
446        self: std::pin::Pin<&mut Self>,
447        cx: &mut std::task::Context<'_>,
448    ) -> std::task::Poll<std::io::Result<()>> {
449        match self.get_mut() {
450            DualStream::Tcp(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
451            #[cfg(unix)]
452            DualStream::Uds(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
453        }
454    }
455}
456
457impl axum::serve::Listener for DualListener {
458    type Io = DualStream;
459    type Addr = DualAddr;
460    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
461        self._accept_axum().await
462    }
463
464    fn local_addr(&self) -> Result<Self::Addr, std::io::Error> {
465        match self {
466            DualListener::Tcp(listener) => Listener::local_addr(listener).map(DualAddr::Tcp),
467            #[cfg(unix)]
468            DualListener::Uds(listener) => Listener::local_addr(listener).map(DualAddr::Uds),
469        }
470    }
471}
472
473const _: () = {
474    use super::DualAddr;
475    use axum::extract::connect_info::Connected;
476    impl Connected<DualAddr> for DualAddr {
477        fn connect_info(remote_addr: DualAddr) -> Self {
478            remote_addr
479        }
480    }
481    use axum::serve;
482
483    impl Connected<serve::IncomingStream<'_, DualListener>> for DualAddr {
484        fn connect_info(stream: serve::IncomingStream<'_, DualListener>) -> Self {
485            stream.remote_addr().clone()
486        }
487    }
488};
489
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_tcp_bind() {
        // Port 0 asks the OS for a free ephemeral port, so the test does not fail
        // when another process (or a concurrent test run) already holds a fixed port.
        let listener = DualListener::bind("localhost:0")
            .await
            .expect("binding localhost:0 should succeed");
        match listener {
            DualListener::Tcp(tcp_listener) => {
                let addr = tcp_listener.local_addr().unwrap();
                assert!(addr.ip().is_loopback());
                assert_ne!(addr.port(), 0);
            }
            #[cfg(unix)]
            DualListener::Uds(_) => panic!("Expected TCP listener"),
        }
    }

    /// Deletes the wrapped socket path when dropped, so a failed assertion does
    /// not leave a stale socket file that would make the next run fail.
    #[cfg(unix)]
    struct RemoveOnDrop(std::path::PathBuf);

    #[cfg(unix)]
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            // Best effort: the file may legitimately not exist.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_uds_bind() {
        // A per-process path avoids collisions between concurrent runs. Binding
        // fails if the socket file already exists, so clear any leftover first.
        let path = std::env::temp_dir().join(format!(
            "axum_listener_test_{}.sock",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&path);
        let _cleanup = RemoveOnDrop(path.clone());

        // The explicit `unix:` prefix makes parsing independent of how the temp
        // directory path happens to start.
        let address = format!("unix:{}", path.display());
        let listener = DualListener::bind(address)
            .await
            .expect("binding a fresh UDS path should succeed");
        match listener {
            DualListener::Uds(uds_listener) => {
                let addr = uds_listener.local_addr().unwrap();
                assert_eq!(addr.as_pathname().unwrap(), path.as_path());
            }
            DualListener::Tcp(_) => panic!("Expected UDS listener"),
        }
    }
}