hyperdriver/stream/
unix.rs

//! Unix Stream implementation with better address semantics for servers.
//!
//! This module provides a `UnixStream` type that wraps `tokio::net::UnixStream` with
//! better address semantics for servers. When a server accepts a connection, it
//! returns the associated `SocketAddr` alongside the stream. On some platforms,
//! this information is not available after the connection is established via
//! `UnixStream::peer_addr`. This module provides a way to retain this information
//! for the lifetime of the stream.
9
10use std::fmt;
11use std::io;
12use std::ops::Deref;
13use std::ops::DerefMut;
14use std::path::Path;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use camino::Utf8Path;
19use camino::Utf8PathBuf;
20use tokio::io::{AsyncRead, AsyncWrite};
21#[cfg(feature = "server")]
22pub use tokio::net::UnixListener;
23
24use crate::info::HasConnectionInfo;
25#[cfg(feature = "server")]
26use crate::server::Accept;
27
28/// Connection address for a unix domain socket.
29#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
30pub struct UnixAddr {
31    path: Option<Utf8PathBuf>,
32}
33
34impl UnixAddr {
35    /// Does this socket have a name
36    pub fn is_named(&self) -> bool {
37        self.path.is_some()
38    }
39
40    /// Get the path of this socket.
41    pub fn path(&self) -> Option<&Utf8Path> {
42        self.path.as_deref()
43    }
44
45    /// Create a new address from a path.
46    pub fn from_pathbuf(path: Utf8PathBuf) -> Self {
47        Self { path: Some(path) }
48    }
49
50    /// Create a new address without a path.
51    pub fn unnamed() -> Self {
52        Self { path: None }
53    }
54}
55
56impl fmt::Display for UnixAddr {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        if let Some(path) = self.path() {
59            write!(f, "unix://{}", path)
60        } else {
61            write!(f, "unix://")
62        }
63    }
64}
65
66impl TryFrom<std::os::unix::net::SocketAddr> for UnixAddr {
67    type Error = io::Error;
68    fn try_from(addr: std::os::unix::net::SocketAddr) -> Result<Self, Self::Error> {
69        Ok(Self {
70            path: addr
71                .as_pathname()
72                .map(|p| {
73                    Utf8Path::from_path(p).ok_or_else(|| {
74                        io::Error::new(io::ErrorKind::InvalidData, "not a utf-8 path")
75                    })
76                })
77                .transpose()?
78                .map(|path| path.to_owned()),
79        })
80    }
81}
82
83impl TryFrom<tokio::net::unix::SocketAddr> for UnixAddr {
84    type Error = io::Error;
85    fn try_from(addr: tokio::net::unix::SocketAddr) -> Result<Self, Self::Error> {
86        Ok(Self {
87            path: addr
88                .as_pathname()
89                .map(|p| {
90                    Utf8Path::from_path(p).ok_or_else(|| {
91                        io::Error::new(io::ErrorKind::InvalidData, "not a utf-8 path")
92                    })
93                })
94                .transpose()?
95                .map(|path| path.to_owned()),
96        })
97    }
98}
99
100/// A Unix Stream, wrapping `tokio::net::UnixStream` with better
101/// address semantics for servers.
102#[pin_project::pin_project]
103pub struct UnixStream {
104    #[pin]
105    stream: tokio::net::UnixStream,
106    remote: Option<UnixAddr>,
107}
108
109impl fmt::Debug for UnixStream {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        self.stream.fmt(f)
112    }
113}
114
115impl UnixStream {
116    /// Connect to a remote address. See `tokio::net::UnixStream::connect`.
117    pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
118        let path = path.as_ref();
119        let stream = tokio::net::UnixStream::connect(path).await?;
120        Ok(Self::new(
121            stream,
122            Some(UnixAddr::from_pathbuf(
123                Utf8PathBuf::from_path_buf(path.to_path_buf()).map_err(|path| {
124                    io::Error::new(
125                        io::ErrorKind::InvalidInput,
126                        format!("unix path is not utf-8: {}", path.display()),
127                    )
128                })?,
129            )),
130        ))
131    }
132
133    /// Create a pair of connected `UnixStream`s. See `tokio::net::UnixStream::pair`.
134    pub fn pair() -> io::Result<(Self, Self)> {
135        let (a, b) = tokio::net::UnixStream::pair()?;
136        Ok((
137            Self::new(a, Some(UnixAddr::unnamed())),
138            Self::new(b, Some(UnixAddr::unnamed())),
139        ))
140    }
141
142    /// Create a new `UnixStream` from an existing `tokio::net::UnixStream` for a
143    /// connection. Most of the time, the remote addr should also be passed here,
144    /// but there may be cases when you are handed the stream without the remote
145    /// addr.
146    pub fn new(inner: tokio::net::UnixStream, remote: Option<UnixAddr>) -> Self {
147        Self {
148            stream: inner,
149            remote,
150        }
151    }
152
153    /// Local address of the connection. See `tokio::net::UnixStream::local_addr`.
154    pub fn local_addr(&self) -> io::Result<UnixAddr> {
155        self.stream.local_addr().and_then(UnixAddr::try_from)
156    }
157
158    /// Remote address of the connection. See `tokio::net::UnixStream::peer_addr`.
159    ///
160    /// For servers, this will return the remote address provided when creating the stream,
161    /// instead of an `io::Error`.
162    pub fn peer_addr(&self) -> io::Result<UnixAddr> {
163        match &self.remote {
164            Some(addr) => Ok(addr.clone()),
165            None => self.stream.peer_addr().and_then(UnixAddr::try_from),
166        }
167    }
168
169    /// Unwraps the `UnixStream`, returning the inner `tokio::net::UnixStream`.
170    pub fn into_inner(self) -> tokio::net::UnixStream {
171        self.stream
172    }
173}
174
175impl Deref for UnixStream {
176    type Target = tokio::net::UnixStream;
177    fn deref(&self) -> &Self::Target {
178        &self.stream
179    }
180}
181
182impl DerefMut for UnixStream {
183    fn deref_mut(&mut self) -> &mut Self::Target {
184        &mut self.stream
185    }
186}
187
188impl HasConnectionInfo for UnixStream {
189    type Addr = UnixAddr;
190    fn info(&self) -> crate::info::ConnectionInfo<Self::Addr> {
191        let remote_addr = self
192            .peer_addr()
193            .expect("peer_addr is available for unix stream");
194        let local_addr = self
195            .local_addr()
196            .expect("local_addr is available for unix stream");
197
198        crate::info::ConnectionInfo {
199            local_addr,
200            remote_addr,
201        }
202    }
203}
204
#[cfg(feature = "client")]
impl crate::client::pool::PoolableStream for UnixStream {
    /// Always reports `false` for unix streams.
    fn can_share(&self) -> bool {
        false
    }
}
211
212impl AsyncRead for UnixStream {
213    fn poll_read(
214        self: Pin<&mut Self>,
215        cx: &mut Context<'_>,
216        buf: &mut tokio::io::ReadBuf<'_>,
217    ) -> Poll<io::Result<()>> {
218        self.project().stream.poll_read(cx, buf)
219    }
220}
221
222impl AsyncWrite for UnixStream {
223    fn poll_write(
224        self: Pin<&mut Self>,
225        cx: &mut Context<'_>,
226        buf: &[u8],
227    ) -> Poll<Result<usize, io::Error>> {
228        self.project().stream.poll_write(cx, buf)
229    }
230
231    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
232        self.project().stream.poll_flush(cx)
233    }
234
235    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
236        self.project().stream.poll_shutdown(cx)
237    }
238
239    fn poll_write_vectored(
240        self: Pin<&mut Self>,
241        cx: &mut Context<'_>,
242        bufs: &[io::IoSlice<'_>],
243    ) -> Poll<Result<usize, io::Error>> {
244        self.project().stream.poll_write_vectored(cx, bufs)
245    }
246
247    fn is_write_vectored(&self) -> bool {
248        self.stream.is_write_vectored()
249    }
250}
251
#[cfg(feature = "server")]
impl Accept for UnixListener {
    type Conn = UnixStream;
    type Error = io::Error;

    /// Poll the listener for a new connection and wrap it in a `UnixStream`
    /// that stores the peer address returned by the accept call.
    ///
    /// # Errors
    ///
    /// Forwards any error from the listener, and also fails when the peer
    /// address cannot be converted into a `UnixAddr`.
    fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Self::Conn>> {
        match UnixListener::poll_accept(self.get_mut(), cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Ready(Ok((stream, remote))) => {
                let conn =
                    UnixAddr::try_from(remote).map(|addr| UnixStream::new(stream, Some(addr)));
                Poll::Ready(conn)
            }
        }
    }
}
262}