1use axum::serve::Listener;
2use std::net::ToSocketAddrs;
3#[cfg(unix)]
4use tokio::net::UnixListener;
5
6#[derive(Debug)]
33pub enum DualListener {
34 Tcp(tokio::net::TcpListener),
36 #[cfg(unix)]
38 Uds(tokio::net::UnixListener),
39}
40
41#[derive(Debug, Clone)]
62#[allow(dead_code)]
63pub enum DualAddr {
64 Tcp(core::net::SocketAddr),
66 #[cfg(unix)]
68 Uds(tokio::net::unix::SocketAddr),
69}
70
71impl From<core::net::SocketAddr> for DualAddr {
72 fn from(addr: core::net::SocketAddr) -> Self {
73 DualAddr::Tcp(addr)
74 }
75}
76
77#[cfg(unix)]
78impl From<tokio::net::unix::SocketAddr> for DualAddr {
79 fn from(addr: tokio::net::unix::SocketAddr) -> Self {
80 DualAddr::Uds(addr)
81 }
82}
83
84impl core::str::FromStr for DualAddr {
85 type Err = std::io::Error;
86
87 fn from_str(s: &str) -> Result<Self, Self::Err> {
88 let unix_like = s.starts_with("/") || s.starts_with("unix:");
89 let has_uds = cfg!(unix);
90 let tcp_like = s.to_socket_addrs().is_ok();
91
92 if unix_like && has_uds && !tcp_like {
93 #[cfg(unix)]
94 {
95 let path = s.trim_start_matches("unix:");
96 let addr = From::from(std::os::unix::net::SocketAddr::from_pathname(path)?);
97 Ok(DualAddr::Uds(addr))
98 }
99 #[cfg(not(unix))]
100 {
101 Err(std::io::Error::new(
102 std::io::ErrorKind::Other,
103 "Unix domain sockets are not supported on this platform",
104 ))
105 }
106 } else if tcp_like {
107 let addr = s.to_socket_addrs()?.next().ok_or_else(|| {
108 std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid TCP address")
109 })?;
110 Ok(DualAddr::Tcp(addr))
111 } else if unix_like && !has_uds {
112 Err(std::io::Error::other(
113 "Unix domain sockets are not supported on this platform",
114 ))
115 } else {
116 Err(std::io::Error::new(
117 std::io::ErrorKind::InvalidInput,
118 "Invalid address format",
119 ))
120 }
121 }
122}
123
124pub trait ToDualAddr {
145 fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error>;
152}
153
154impl ToDualAddr for &str {
155 fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
156 self.parse()
157 }
158}
159
160impl ToDualAddr for String {
161 fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
162 self.as_str().to_dual_addr()
163 }
164}
165
166impl ToDualAddr for core::net::SocketAddr {
167 fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
168 Ok(DualAddr::Tcp(*self))
169 }
170}
171
172#[cfg(unix)]
173impl ToDualAddr for tokio::net::unix::SocketAddr {
174 fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
175 Ok(DualAddr::Uds(self.clone()))
176 }
177}
178
179impl ToDualAddr for DualAddr {
180 fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
181 Ok(self.clone())
182 }
183}
184
185impl ToDualAddr for &DualAddr {
186 fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
187 Ok((*self).clone())
188 }
189}
190
191#[cfg(unix)]
192impl ToDualAddr for &std::path::Path {
193 fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
194 Ok(DualAddr::Uds(From::from(
195 std::os::unix::net::SocketAddr::from_pathname(self)?,
196 )))
197 }
198}
199
200#[cfg(unix)]
201impl ToDualAddr for std::path::PathBuf {
202 fn to_dual_addr(&self) -> Result<DualAddr, std::io::Error> {
203 self.as_path().to_dual_addr()
204 }
205}
206
207impl DualListener {
208 pub async fn bind<A: ToDualAddr>(address: A) -> Result<Self, std::io::Error> {
247 let address = address.to_dual_addr()?;
248 match address {
249 DualAddr::Tcp(addr) => {
250 let listener = tokio::net::TcpListener::bind(addr).await?;
251 Ok(DualListener::Tcp(listener))
252 }
253 #[cfg(unix)]
254 DualAddr::Uds(ref addr) => {
255 let path = addr.as_pathname().ok_or_else(|| {
256 std::io::Error::new(
257 std::io::ErrorKind::InvalidInput,
258 "UDS address does not have a valid pathname",
259 )
260 })?;
261 let listener = UnixListener::bind(path)?;
262 Ok(DualListener::Uds(listener))
263 }
264 #[cfg(not(unix))]
265 DualAddr::Uds(_) => Err(std::io::Error::new(
266 std::io::ErrorKind::Other,
267 "Unix domain sockets are not supported on this platform",
268 )),
269 }
270 }
271
272 pub async fn accept(&self) -> Result<(DualStream, DualAddr), std::io::Error> {
301 match self {
302 DualListener::Tcp(listener) => {
303 let (stream, addr) = listener.accept().await?;
304 Ok((DualStream::Tcp(stream), DualAddr::Tcp(addr)))
305 }
306 #[cfg(unix)]
307 DualListener::Uds(listener) => {
308 let (stream, addr) = listener.accept().await?;
309 Ok((DualStream::Uds(stream), DualAddr::Uds(addr)))
310 }
311 }
312 }
313
314 pub(crate) fn _accept_unpin(
315 &self,
316 ) -> impl core::future::Future<Output = Result<(DualStream, DualAddr), std::io::Error>>
317 + Unpin
318 + use<'_> {
319 Box::pin(async move {
320 match self {
321 DualListener::Tcp(listener) => {
322 let (stream, addr) = listener.accept().await?;
323 Ok((DualStream::Tcp(stream), DualAddr::Tcp(addr)))
324 }
325 #[cfg(unix)]
326 DualListener::Uds(listener) => {
327 let (stream, addr) = listener.accept().await?;
328 Ok((DualStream::Uds(stream), DualAddr::Uds(addr)))
329 }
330 }
331 })
332 }
333 pub(crate) async fn _accept_axum(&mut self) -> (DualStream, DualAddr) {
334 match self {
335 DualListener::Tcp(listener) => {
336 let (stream, addr) = Listener::accept(listener).await;
337 (DualStream::Tcp(stream), DualAddr::Tcp(addr))
338 }
339 #[cfg(unix)]
340 DualListener::Uds(listener) => {
341 let (stream, addr) = Listener::accept(listener).await;
342 (DualStream::Uds(stream), DualAddr::Uds(addr))
343 }
344 }
345 }
346
347 pub(crate) fn _accept_axum_unpin(
348 &mut self,
349 ) -> impl core::future::Future<Output = (DualStream, DualAddr)> + Unpin + use<'_> {
350 Box::pin(async move {
351 match self {
352 DualListener::Tcp(listener) => {
353 let (stream, addr) = Listener::accept(listener).await;
354 (DualStream::Tcp(stream), DualAddr::Tcp(addr))
355 }
356 #[cfg(unix)]
357 DualListener::Uds(listener) => {
358 let (stream, addr) = Listener::accept(listener).await;
359 (DualStream::Uds(stream), DualAddr::Uds(addr))
360 }
361 }
362 })
363 }
364}
365
366pub enum DualStream {
387 Tcp(tokio::net::TcpStream),
389 #[cfg(unix)]
391 Uds(tokio::net::UnixStream),
392}
393
394impl From<tokio::net::TcpStream> for DualStream {
395 fn from(stream: tokio::net::TcpStream) -> Self {
396 DualStream::Tcp(stream)
397 }
398}
399
400#[cfg(unix)]
401impl From<tokio::net::UnixStream> for DualStream {
402 fn from(stream: tokio::net::UnixStream) -> Self {
403 DualStream::Uds(stream)
404 }
405}
406
407impl tokio::io::AsyncRead for DualStream {
408 fn poll_read(
409 self: std::pin::Pin<&mut Self>,
410 cx: &mut std::task::Context<'_>,
411 buf: &mut tokio::io::ReadBuf<'_>,
412 ) -> std::task::Poll<std::io::Result<()>> {
413 match self.get_mut() {
414 DualStream::Tcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
415 #[cfg(unix)]
416 DualStream::Uds(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
417 }
418 }
419}
420
421impl tokio::io::AsyncWrite for DualStream {
422 fn poll_write(
423 self: std::pin::Pin<&mut Self>,
424 cx: &mut std::task::Context<'_>,
425 buf: &[u8],
426 ) -> std::task::Poll<std::io::Result<usize>> {
427 match self.get_mut() {
428 DualStream::Tcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
429 #[cfg(unix)]
430 DualStream::Uds(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
431 }
432 }
433
434 fn poll_flush(
435 self: std::pin::Pin<&mut Self>,
436 cx: &mut std::task::Context<'_>,
437 ) -> std::task::Poll<std::io::Result<()>> {
438 match self.get_mut() {
439 DualStream::Tcp(stream) => std::pin::Pin::new(stream).poll_flush(cx),
440 #[cfg(unix)]
441 DualStream::Uds(stream) => std::pin::Pin::new(stream).poll_flush(cx),
442 }
443 }
444
445 fn poll_shutdown(
446 self: std::pin::Pin<&mut Self>,
447 cx: &mut std::task::Context<'_>,
448 ) -> std::task::Poll<std::io::Result<()>> {
449 match self.get_mut() {
450 DualStream::Tcp(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
451 #[cfg(unix)]
452 DualStream::Uds(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
453 }
454 }
455}
456
457impl axum::serve::Listener for DualListener {
458 type Io = DualStream;
459 type Addr = DualAddr;
460 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
461 self._accept_axum().await
462 }
463
464 fn local_addr(&self) -> Result<Self::Addr, std::io::Error> {
465 match self {
466 DualListener::Tcp(listener) => Listener::local_addr(listener).map(DualAddr::Tcp),
467 #[cfg(unix)]
468 DualListener::Uds(listener) => Listener::local_addr(listener).map(DualAddr::Uds),
469 }
470 }
471}
472
473const _: () = {
474 use super::DualAddr;
475 use axum::extract::connect_info::Connected;
476 impl Connected<DualAddr> for DualAddr {
477 fn connect_info(remote_addr: DualAddr) -> Self {
478 remote_addr
479 }
480 }
481 use axum::serve;
482
483 impl Connected<serve::IncomingStream<'_, DualListener>> for DualAddr {
484 fn connect_info(stream: serve::IncomingStream<'_, DualListener>) -> Self {
485 stream.remote_addr().clone()
486 }
487 }
488};
489
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a socket path in the system temp directory that is unique to this
    /// process and `tag`, so reruns and concurrent test binaries do not collide.
    #[cfg(unix)]
    fn unique_socket_path(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("dual-{tag}-{}.sock", std::process::id()))
    }

    #[test]
    fn test_parse_tcp_and_invalid() {
        let parsed = "127.0.0.1:8080".parse::<DualAddr>();
        assert!(matches!(parsed, Ok(DualAddr::Tcp(addr)) if addr.port() == 8080));
        assert!("not an address".parse::<DualAddr>().is_err());
    }

    #[cfg(unix)]
    #[test]
    fn test_parse_unix_prefix() {
        let parsed: DualAddr = "unix:/tmp/dual.sock".parse().unwrap();
        if let DualAddr::Uds(addr) = parsed {
            assert_eq!(
                addr.as_pathname(),
                Some(std::path::Path::new("/tmp/dual.sock"))
            );
        } else {
            panic!("Expected UDS address");
        }
    }

    #[tokio::test]
    async fn test_tcp_bind() {
        // Port 0 lets the OS pick a free port, so the test does not fail just
        // because a fixed port (previously 8080) is already taken on the machine.
        let listener = DualListener::bind("127.0.0.1:0").await;
        assert!(listener.is_ok());
        if let DualListener::Tcp(tcp_listener) = listener.unwrap() {
            let addr = tcp_listener.local_addr().unwrap();
            assert!(addr.ip().is_loopback());
            assert_ne!(addr.port(), 0);
        } else {
            panic!("Expected TCP listener");
        }
    }

    #[tokio::test]
    async fn test_tcp_accept() {
        let listener = DualListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = if let DualListener::Tcp(tcp_listener) = &listener {
            tcp_listener.local_addr().unwrap()
        } else {
            panic!("Expected TCP listener");
        };

        let (accepted, connected) = tokio::join!(
            listener.accept(),
            tokio::net::TcpStream::connect(server_addr)
        );
        let client = connected.unwrap();
        let (_stream, peer) = accepted.unwrap();
        if let DualAddr::Tcp(peer) = peer {
            assert_eq!(peer, client.local_addr().unwrap());
        } else {
            panic!("Expected TCP peer address");
        }
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_uds_bind() {
        let path = unique_socket_path("bind");
        // Nothing here removes the socket file when the listener is dropped, so
        // clear any leftover from an earlier aborted run before binding.
        let _ = std::fs::remove_file(&path);

        let listener = DualListener::bind(path.as_path())
            .await
            .expect("binding a fresh Unix socket path should succeed");
        let bound = if let DualListener::Uds(uds_listener) = &listener {
            uds_listener
                .local_addr()
                .unwrap()
                .as_pathname()
                .map(std::path::Path::to_path_buf)
        } else {
            None
        };

        // Clean up before asserting so a failure does not leave the file behind.
        drop(listener);
        let _ = std::fs::remove_file(&path);

        assert_eq!(bound.as_deref(), Some(path.as_path()), "Expected UDS listener");
    }
}