compio_net/
unix.rs

1use std::{future::Future, io, path::Path};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
6use compio_runtime::{BorrowedBuffer, BufferPool};
7use socket2::{SockAddr, Socket as Socket2, Type};
8
9use crate::{OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, WriteHalf};
10
/// A Unix domain socket server, listening for incoming connections.
///
/// Create one with [`UnixListener::bind`] or [`UnixListener::bind_addr`], then
/// accept a new connection by using the [`UnixListener::accept`] method, which
/// yields a [`UnixStream`] together with the peer address.
///
/// # Examples
///
/// ```
/// use compio_io::{AsyncReadExt, AsyncWriteExt};
/// use compio_net::{UnixListener, UnixStream};
/// use tempfile::tempdir;
///
/// let dir = tempdir().unwrap();
/// let sock_file = dir.path().join("unix-server.sock");
///
/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
/// let listener = UnixListener::bind(&sock_file).await.unwrap();
///
/// let (mut tx, (mut rx, _)) =
///     futures_util::try_join!(UnixStream::connect(&sock_file), listener.accept()).unwrap();
///
/// tx.write_all("test").await.0.unwrap();
///
/// let (_, buf) = rx.read_exact(Vec::with_capacity(4)).await.unwrap();
///
/// assert_eq!(buf, b"test");
/// # });
/// ```
#[derive(Debug, Clone)]
pub struct UnixListener {
    // The listening socket; the methods of this type forward to it.
    inner: Socket,
}
43
44impl UnixListener {
45    /// Creates a new [`UnixListener`], which will be bound to the specified
46    /// file path. The file path cannot yet exist, and will be cleaned up
47    /// upon dropping [`UnixListener`]
48    pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
49        Self::bind_addr(&SockAddr::unix(path)?).await
50    }
51
52    /// Creates a new [`UnixListener`] with [`SockAddr`], which will be bound to
53    /// the specified file path. The file path cannot yet exist, and will be
54    /// cleaned up upon dropping [`UnixListener`]
55    pub async fn bind_addr(addr: &SockAddr) -> io::Result<Self> {
56        if !addr.is_unix() {
57            return Err(io::Error::new(
58                io::ErrorKind::InvalidInput,
59                "addr is not unix socket address",
60            ));
61        }
62
63        let socket = Socket::bind(addr, Type::STREAM, None).await?;
64        socket.listen(1024)?;
65        Ok(UnixListener { inner: socket })
66    }
67
68    #[cfg(unix)]
69    /// Creates new UnixListener from a [`std::os::unix::net::UnixListener`].
70    pub fn from_std(stream: std::os::unix::net::UnixListener) -> io::Result<Self> {
71        Ok(Self {
72            inner: Socket::from_socket2(Socket2::from(stream))?,
73        })
74    }
75
76    /// Close the socket. If the returned future is dropped before polling, the
77    /// socket won't be closed.
78    pub fn close(self) -> impl Future<Output = io::Result<()>> {
79        self.inner.close()
80    }
81
82    /// Accepts a new incoming connection from this listener.
83    ///
84    /// This function will yield once a new Unix domain socket connection
85    /// is established. When established, the corresponding [`UnixStream`] and
86    /// will be returned.
87    pub async fn accept(&self) -> io::Result<(UnixStream, SockAddr)> {
88        let (socket, addr) = self.inner.accept().await?;
89        let stream = UnixStream { inner: socket };
90        Ok((stream, addr))
91    }
92
93    /// Returns the local address that this listener is bound to.
94    pub fn local_addr(&self) -> io::Result<SockAddr> {
95        self.inner.local_addr()
96    }
97}
98
// NOTE(review): `impl_raw_fd!` is imported from `compio_driver`; presumably it
// generates the raw fd/socket conversion impls for `UnixListener` by forwarding
// to its `inner` field — confirm against the macro definition.
impl_raw_fd!(UnixListener, socket2::Socket, inner, socket);
100
/// A Unix domain socket stream between two local endpoints.
///
/// This type is usable on Unix as well as on Windows (including WSL
/// interoperability), where it relies on the `AF_UNIX` support of WinSock.
///
/// A Unix stream can either be created by connecting to an endpoint, via the
/// `connect` method, or by accepting a connection from a listener.
///
/// # Examples
///
/// ```no_run
/// use compio_io::AsyncWrite;
/// use compio_net::UnixStream;
///
/// # compio_runtime::Runtime::new().unwrap().block_on(async {
/// // Connect to a peer
/// let mut stream = UnixStream::connect("unix-server.sock").await.unwrap();
///
/// // Write some data.
/// stream.write("hello world!").await.unwrap();
/// # })
/// ```
#[derive(Debug, Clone)]
pub struct UnixStream {
    // The connected socket; reads, writes and address queries forward to it.
    inner: Socket,
}
124
125impl UnixStream {
126    /// Opens a Unix connection to the specified file path. There must be a
127    /// [`UnixListener`] or equivalent listening on the corresponding Unix
128    /// domain socket to successfully connect and return a `UnixStream`.
129    pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
130        Self::connect_addr(&SockAddr::unix(path)?).await
131    }
132
133    /// Opens a Unix connection to the specified address. There must be a
134    /// [`UnixListener`] or equivalent listening on the corresponding Unix
135    /// domain socket to successfully connect and return a `UnixStream`.
136    pub async fn connect_addr(addr: &SockAddr) -> io::Result<Self> {
137        if !addr.is_unix() {
138            return Err(io::Error::new(
139                io::ErrorKind::InvalidInput,
140                "addr is not unix socket address",
141            ));
142        }
143
144        #[cfg(windows)]
145        let socket = {
146            let new_addr = empty_unix_socket();
147            Socket::bind(&new_addr, Type::STREAM, None).await?
148        };
149        #[cfg(unix)]
150        let socket = {
151            use socket2::Domain;
152            Socket::new(Domain::UNIX, Type::STREAM, None).await?
153        };
154        socket.connect_async(addr).await?;
155        let unix_stream = UnixStream { inner: socket };
156        Ok(unix_stream)
157    }
158
159    #[cfg(unix)]
160    /// Creates new UnixStream from a [`std::os::unix::net::UnixStream`].
161    pub fn from_std(stream: std::os::unix::net::UnixStream) -> io::Result<Self> {
162        Ok(Self {
163            inner: Socket::from_socket2(Socket2::from(stream))?,
164        })
165    }
166
167    /// Close the socket. If the returned future is dropped before polling, the
168    /// socket won't be closed.
169    pub fn close(self) -> impl Future<Output = io::Result<()>> {
170        self.inner.close()
171    }
172
173    /// Returns the socket path of the remote peer of this connection.
174    pub fn peer_addr(&self) -> io::Result<SockAddr> {
175        #[allow(unused_mut)]
176        let mut addr = self.inner.peer_addr()?;
177        #[cfg(windows)]
178        {
179            fix_unix_socket_length(&mut addr);
180        }
181        Ok(addr)
182    }
183
184    /// Returns the socket path of the local half of this connection.
185    pub fn local_addr(&self) -> io::Result<SockAddr> {
186        self.inner.local_addr()
187    }
188
189    /// Splits a [`UnixStream`] into a read half and a write half, which can be
190    /// used to read and write the stream concurrently.
191    ///
192    /// This method is more efficient than
193    /// [`into_split`](UnixStream::into_split), but the halves cannot
194    /// be moved into independently spawned tasks.
195    pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
196        crate::split(self)
197    }
198
199    /// Splits a [`UnixStream`] into a read half and a write half, which can be
200    /// used to read and write the stream concurrently.
201    ///
202    /// Unlike [`split`](UnixStream::split), the owned halves can be moved to
203    /// separate tasks, however this comes at the cost of a heap allocation.
204    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
205        crate::into_split(self)
206    }
207
208    /// Create [`PollFd`] from inner socket.
209    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
210        self.inner.to_poll_fd()
211    }
212
213    /// Create [`PollFd`] from inner socket.
214    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
215        self.inner.into_poll_fd()
216    }
217}
218
219impl AsyncRead for UnixStream {
220    #[inline]
221    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
222        (&*self).read(buf).await
223    }
224
225    #[inline]
226    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
227        (&*self).read_vectored(buf).await
228    }
229}
230
231impl AsyncRead for &UnixStream {
232    #[inline]
233    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
234        self.inner.recv(buf).await
235    }
236
237    #[inline]
238    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
239        self.inner.recv_vectored(buf).await
240    }
241}
242
243impl AsyncReadManaged for UnixStream {
244    type Buffer<'a> = BorrowedBuffer<'a>;
245    type BufferPool = BufferPool;
246
247    async fn read_managed<'a>(
248        &mut self,
249        buffer_pool: &'a Self::BufferPool,
250        len: usize,
251    ) -> io::Result<Self::Buffer<'a>> {
252        (&*self).read_managed(buffer_pool, len).await
253    }
254}
255
256impl AsyncReadManaged for &UnixStream {
257    type Buffer<'a> = BorrowedBuffer<'a>;
258    type BufferPool = BufferPool;
259
260    async fn read_managed<'a>(
261        &mut self,
262        buffer_pool: &'a Self::BufferPool,
263        len: usize,
264    ) -> io::Result<Self::Buffer<'a>> {
265        self.inner.recv_managed(buffer_pool, len as _).await
266    }
267}
268
269impl AsyncWrite for UnixStream {
270    #[inline]
271    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
272        (&*self).write(buf).await
273    }
274
275    #[inline]
276    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
277        (&*self).write_vectored(buf).await
278    }
279
280    #[inline]
281    async fn flush(&mut self) -> io::Result<()> {
282        (&*self).flush().await
283    }
284
285    #[inline]
286    async fn shutdown(&mut self) -> io::Result<()> {
287        (&*self).shutdown().await
288    }
289}
290
291impl AsyncWrite for &UnixStream {
292    #[inline]
293    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
294        self.inner.send(buf).await
295    }
296
297    #[inline]
298    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
299        self.inner.send_vectored(buf).await
300    }
301
302    #[inline]
303    async fn flush(&mut self) -> io::Result<()> {
304        Ok(())
305    }
306
307    #[inline]
308    async fn shutdown(&mut self) -> io::Result<()> {
309        self.inner.shutdown().await
310    }
311}
312
313impl Splittable for UnixStream {
314    type ReadHalf = OwnedReadHalf<Self>;
315    type WriteHalf = OwnedWriteHalf<Self>;
316
317    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
318        crate::into_split(self)
319    }
320}
321
322impl<'a> Splittable for &'a UnixStream {
323    type ReadHalf = ReadHalf<'a, UnixStream>;
324    type WriteHalf = WriteHalf<'a, UnixStream>;
325
326    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
327        crate::split(self)
328    }
329}
330
// NOTE(review): `impl_raw_fd!` is imported from `compio_driver`; presumably it
// generates the raw fd/socket conversion impls for `UnixStream` by forwarding
// to its `inner` field — confirm against the macro definition.
impl_raw_fd!(UnixStream, socket2::Socket, inner, socket);
332
/// Builds a Unix socket address with an all-zero path and a reported length
/// of 3 bytes (Windows only).
///
/// `UnixStream::connect_addr` binds a fresh socket to this address before
/// connecting.
#[cfg(windows)]
#[inline]
fn empty_unix_socket() -> SockAddr {
    use windows_sys::Win32::Networking::WinSock::{AF_UNIX, SOCKADDR_UN};

    let unnamed = SOCKADDR_UN {
        sun_family: AF_UNIX,
        sun_path: [0; 108],
    };

    // SAFETY: the closure writes a complete `SOCKADDR_UN` into the storage it
    // is given and reports a length (3) that does not exceed that structure.
    // This assumes the storage provided by `SockAddr::try_init` is large
    // enough and suitably aligned for `SOCKADDR_UN`.
    let initialised = unsafe {
        SockAddr::try_init(|storage, len| {
            storage.cast::<SOCKADDR_UN>().write(unnamed);
            // presumably the 2-byte `sun_family` plus one NUL path byte — TODO confirm
            len.write(3);
            Ok(())
        })
    };

    // it is always Ok
    let (_, addr) = initialised.unwrap();
    addr
}
357
/// Corrects the length stored in a Unix [`SockAddr`] (Windows only).
///
/// The peer addr returned after ConnectEx is buggy. It contains bytes that
/// should not belong to the address. Luckily a unix path should not contain
/// `\0` until the end, so the end of the path is found by its first NUL byte.
///
/// The new length is the path length including that NUL plus 2; when
/// `sun_path` holds no NUL at all, the full size of `SOCKADDR_UN` is used.
#[cfg(windows)]
#[inline]
fn fix_unix_socket_length(addr: &mut SockAddr) {
    use std::{ffi::CStr, mem::size_of};

    use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN;

    let raw: *const SOCKADDR_UN = addr.as_ptr().cast();
    // SAFETY: cannot construct non-unix socket address in safe way.
    let unix_addr = unsafe { &*raw };

    let addr_len = CStr::from_bytes_until_nul(&unix_addr.sun_path)
        .map(|path| path.to_bytes_with_nul().len() + 2)
        .unwrap_or(size_of::<SOCKADDR_UN>());

    // SAFETY: `addr_len` is at most `size_of::<SOCKADDR_UN>()`, i.e. within the
    // structure just read from this address (assuming a 2-byte `sun_family`
    // precedes `sun_path` — TODO confirm).
    unsafe {
        addr.set_length(addr_len as _);
    }
}