asyncfd/
lib.rs

1use std::collections::VecDeque;
2use std::ffi::{c_int, c_void};
3use std::io::{Error, Result};
4use std::os::fd::{AsRawFd, IntoRawFd, RawFd};
5use std::os::unix::net::UnixStream;
6use std::pin::Pin;
7use std::sync::{Arc, Mutex};
8use std::task::{ready, Context, Poll};
9
10use tokio::io::unix::AsyncFd;
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
13
14mod header;
15pub mod split;
16pub mod split_owned;
17
18/// A wrapper around a `UnixStream` that allows file descriptors to be
19/// sent and received with messages.  Implements `AsyncRead` and
20/// `AsyncWrite` such that standard asynchronous reading and writing
21/// operations and helpers may be used.
22pub struct UnixFdStream<T: AsRawFd> {
23    inner: AsyncFd<T>,
24    incoming_fds: Mutex<VecDeque<RawFd>>,
25    outgoing_tx: UnboundedSender<RawFd>,
26    outgoing_rx: Option<UnboundedReceiver<RawFd>>,
27    max_read_fds: usize,
28}
29
/// Half- or full-close of a socket.
///
/// `UnixFdStream<T>` only implements `AsyncWrite` when `T: Shutdown`, because
/// `poll_shutdown` forwards to this method to close the write direction.
pub trait Shutdown {
    /// Shut down the read side, the write side, or both, as selected by `how`.
    fn shutdown(&self, how: std::net::Shutdown) -> Result<()>;
}
34
35impl Shutdown for UnixStream {
36    fn shutdown(&self, how: std::net::Shutdown) -> Result<()> {
37        UnixStream::shutdown(self, how)
38    }
39}
40
/// Switching a socket between blocking and non-blocking mode.
///
/// `UnixFdStream::new` requires this because the socket has to be
/// non-blocking before it is handed to tokio's `AsyncFd`.
pub trait NonBlocking {
    /// Enable non-blocking mode when `nonblocking` is `true`, or restore
    /// blocking mode when it is `false`.
    fn set_nonblocking(&self, nonblocking: bool) -> Result<()>;
}
46
47impl NonBlocking for UnixStream {
48    fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
49        UnixStream::set_nonblocking(&self, nonblocking)
50    }
51}
52
53pub(crate) unsafe fn close_fds<T: IntoIterator<Item = RawFd>>(fds: T) {
54    for fd in fds.into_iter() {
55        libc::close(fd);
56    }
57}
58
59impl<T: AsRawFd + NonBlocking> UnixFdStream<T> {
60    /// Create a new `UnixFdStream` from a `UnixStream` which is also
61    /// configured to read up to `max_read_fds` for each read from the
62    /// socket.
63    ///
64    /// The file descriptors that are transferred are buffered in a
65    /// `Vec<RawFd>`, but only so many will have space made for them
66    /// in the receiving header as configured by `max_read_fds`, other
67    /// file descriptors sent beyond this limit will be discarded by the
68    /// kernel.  We do not check for the MSG_CTRUNC flag, therefore this
69    /// will be a silent discard.
70    pub fn new(unix: T, max_read_fds: usize) -> Result<Self> {
71        unix.set_nonblocking(true)?;
72        let (outgoing_tx, outgoing_rx) = tokio::sync::mpsc::unbounded_channel();
73        Ok(Self {
74            inner: AsyncFd::new(unix)?,
75            incoming_fds: Mutex::new(VecDeque::new()),
76            outgoing_tx,
77            outgoing_rx: Some(outgoing_rx),
78            max_read_fds,
79        })
80    }
81}
82
83impl<T: AsRawFd> UnixFdStream<T> {
84    pub fn split<'a>(
85        &'a mut self,
86    ) -> (
87        crate::split::ReadHalf<'a, T>,
88        crate::split::WriteHalf<'a, T>,
89    ) {
90        let read =
91            crate::split::ReadHalf::<T>::new(&self.inner, &self.incoming_fds, &self.max_read_fds);
92        let write = crate::split::WriteHalf::<T>::new(
93            &self.inner,
94            &self.outgoing_tx,
95            self.outgoing_rx.as_mut().unwrap(),
96        );
97        (read, write)
98    }
99
100    pub fn into_split(
101        mut self,
102    ) -> (
103        crate::split_owned::OwnedReadHalf<T>,
104        crate::split_owned::OwnedWriteHalf<T>,
105    ) {
106        let rx: UnboundedReceiver<i32> = self.outgoing_rx.take().unwrap();
107        let own_self = Arc::new(self);
108        let write = crate::split_owned::OwnedWriteHalf::new(
109            own_self.clone(),
110            own_self.outgoing_tx.clone(),
111            rx,
112        );
113        (crate::split_owned::OwnedReadHalf::new(own_self), write)
114    }
115
116    /// Push a file descriptor to be written with the next message that
117    /// is written to this stream.  The ownership is transferred and the
118    /// file descriptor is either closed when the message is sent or this
119    /// instance is dropped.
120    pub fn push_outgoing_fd<F: IntoRawFd>(&self, fd: F) {
121        if let Err(fd) = self.outgoing_tx.send(fd.into_raw_fd()) {
122            // This should never happen, but implemented for completeness.
123            // SAFETY: We just failed to push this file descriptor, so we have to
124            //         close it.
125            unsafe {
126                libc::close(fd.0);
127            }
128        }
129    }
130
131    /// Wait for the underlying UnixStream to become readable.
132    pub async fn readable(&self) -> Result<()> {
133        self.inner.readable().await?.retain_ready();
134        Ok(())
135    }
136
137    /// Get the most recent file descriptor that was read with a message.
138    pub fn pop_incoming_fd(&self) -> Option<RawFd> {
139        if let Ok(mut guard) = self.incoming_fds.lock() {
140            guard.pop_front()
141        } else {
142            None
143        }
144    }
145
146    /// Get the number of file descriptors in the incoming queue.
147    pub fn incoming_count(&self) -> usize {
148        self.incoming_fds
149            .lock()
150            .map(|guard| guard.len())
151            .unwrap_or(0)
152    }
153
154    fn write_simple(socket: RawFd, buf: &[u8]) -> Result<usize> {
155        // SAFETY: The socket is owned by us and the buffer is of known size.
156        let rv = unsafe { libc::send(socket, buf.as_ptr() as *const c_void, buf.len(), 0) };
157        if rv < 0 {
158            return Err(std::io::Error::last_os_error());
159        }
160        Ok(rv as usize)
161    }
162
163    fn add_to_outgoing(&mut self, mut fds: Vec<RawFd>) {
164        // Just in case there were other file descriptors added, pull them from the channel.
165        while let Ok(fd) = self.outgoing_rx.as_mut().unwrap().try_recv() {
166            fds.push(fd);
167        }
168        // Push all the file descriptors to the channel in order.
169        for fd in fds.into_iter() {
170            if let Err(fd) = self.outgoing_tx.send(fd) {
171                // This is impossible as we own the rx, but just for completeness.
172                // SAFETY: We own this file descriptor and are about to drop it on the
173                //         floor, so it's safe to close it.
174                unsafe {
175                    libc::close(fd.0);
176                }
177            }
178        }
179    }
180
181    fn raw_write(socket: RawFd, outgoing_fds: &[RawFd], buf: &[u8]) -> Result<usize> {
182        if outgoing_fds.is_empty() {
183            return Self::write_simple(socket, buf);
184        }
185        let header = crate::header::Header::new(outgoing_fds.len())?;
186        let mut iov = libc::iovec {
187            iov_base: buf.as_ptr() as *mut c_void,
188            iov_len: buf.len(),
189        };
190        // SAFETY: Not really sure why this method is unsafe.
191        let control_length = unsafe { libc::CMSG_LEN(header.data_length as u32) } as _;
192        let msg = libc::msghdr {
193            msg_iov: &mut iov,
194            msg_iovlen: 1,
195            msg_name: std::ptr::null_mut(),
196            msg_namelen: 0,
197            msg_control: header.as_ptr(),
198            msg_controllen: control_length,
199            msg_flags: 0,
200        };
201        // SAFETY: We have constructed the msghdr correctly, so this will point to
202        //         the allocated memory within `header`.
203        let cmsg = unsafe { &mut *libc::CMSG_FIRSTHDR(&msg) };
204        cmsg.cmsg_len = control_length;
205        cmsg.cmsg_type = libc::SCM_RIGHTS;
206        cmsg.cmsg_level = libc::SOL_SOCKET;
207        // SAFETY: We have allocated correctly aligned memory, so this will point to
208        //         the allocated memory within `header`.
209        let mut data = unsafe { libc::CMSG_DATA(cmsg) as *mut c_int };
210        for fd in outgoing_fds {
211            // SAFETY: We have a valid pointer to `header` and now we are copying
212            //         the data that we created space for into it.
213            data = unsafe {
214                std::ptr::write_unaligned(data, *fd as c_int);
215                data.add(1)
216            };
217        }
218        // SAFETY: We just set up the message to send, so we're all safe to attempt to
219        //         send it, also the socket that we are sending on is owned by us.
220        let rv = unsafe { libc::sendmsg(socket, &msg, 0) };
221        if rv < 0 {
222            return Err(std::io::Error::last_os_error());
223        }
224        Ok(rv as usize)
225    }
226
227    fn read_simple(fd: RawFd, buf: &mut [u8]) -> Result<usize> {
228        // SAFETY: The socket is owned by us and the buffer is of known size.
229        let rv = unsafe { libc::recv(fd, buf.as_mut_ptr() as *mut c_void, buf.len(), 0) };
230        if rv < 0 {
231            return Err(std::io::Error::last_os_error());
232        }
233        Ok(rv as usize)
234    }
235
236    fn read_fds(msg: &libc::msghdr) -> Result<VecDeque<RawFd>> {
237        // SAFETY: We set up the buffers correctly and we assume the kernel
238        //         passes us safe data.
239        let mut cmsg_ptr = unsafe { libc::CMSG_FIRSTHDR(msg) };
240        let mut read_fds = VecDeque::<RawFd>::new();
241        while !cmsg_ptr.is_null() {
242            // SAFETY: We just checked for NULL, the header was initialised to zero
243            //         and we assume the kernel passes us safe data.
244            let cmsg = unsafe { &*cmsg_ptr };
245            if cmsg.cmsg_level == libc::SOL_SOCKET && cmsg.cmsg_type == libc::SCM_RIGHTS {
246                // SAFETY: We just checked the header type and assume that the kernel
247                //         is passing us valid data.
248                let mut data = unsafe { libc::CMSG_DATA(cmsg) as *const c_int };
249                // SAFETY: Calculating a past the end pointer that is only accessed in
250                //         an unaligned safe manner.
251                let data_end =
252                    unsafe { (cmsg_ptr as *const u8).add(cmsg.cmsg_len as usize) as *const i32 };
253                while data < data_end {
254                    // SAFETY: We are checking that the data is within the header size
255                    //         each iteration.
256                    let fd = unsafe { std::ptr::read_unaligned(data) };
257                    // SAFETY: The kernel just passed us this file descriptor.
258                    let result = unsafe { libc::fcntl(fd, libc::F_SETFD, libc::FD_CLOEXEC) };
259                    read_fds.push_back(fd);
260                    if result < 0 {
261                        // SAFETY: We have just read these FDs, so it's safe to close them.
262                        unsafe { close_fds(read_fds) };
263                        return Err(Error::last_os_error());
264                    }
265                    // SAFETY: We are just about to test this against the past-the-end pointer
266                    //         as we go around the loop.
267                    data = unsafe { data.add(1) };
268                }
269            }
270            // SAFETY: We set up the buffers correctly and we assume the kernel
271            //         passes us safe data.
272            cmsg_ptr = unsafe { libc::CMSG_NXTHDR(msg, cmsg_ptr) };
273        }
274        Ok(read_fds)
275    }
276
277    fn raw_read(
278        max_read_fds: usize,
279        fd: RawFd,
280        buf: &mut [u8],
281    ) -> Result<(usize, VecDeque<RawFd>)> {
282        // Shortcut in case this was used without any file descriptor
283        // read buffer, maybe the user just wants to send file descriptors.
284        if max_read_fds == 0 {
285            return Self::read_simple(fd, buf).map(|bytes| (bytes, VecDeque::new()));
286        }
287        let header = crate::header::Header::new(max_read_fds)?;
288        let mut iov = libc::iovec {
289            iov_base: buf.as_mut_ptr() as *mut c_void,
290            iov_len: buf.len(),
291        };
292        // SAFETY: Just calculating the length of the header to send.
293        let control_length = unsafe { libc::CMSG_LEN(header.header_length as u32) } as _;
294        let mut msg = libc::msghdr {
295            msg_name: std::ptr::null_mut(),
296            msg_namelen: 0,
297            msg_iov: &mut iov,
298            msg_iovlen: 1,
299            msg_control: header.as_ptr(),
300            msg_controllen: control_length,
301            msg_flags: 0,
302        };
303        // SAFETY: We own the socket and have just created and set up the message
304        //         headers correctly.
305        let read_bytes = match unsafe { libc::recvmsg(fd, &mut msg, 0) } {
306            0 => return Ok((0, VecDeque::new())),
307            rv if rv < 0 => Err(Error::last_os_error()),
308            rv => Ok(rv as usize),
309        }?;
310        let read_fds = UnixFdStream::<T>::read_fds(&msg)?;
311        Ok((read_bytes, read_fds))
312    }
313}
314
315impl<T: AsRawFd> Drop for UnixFdStream<T> {
316    fn drop(&mut self) {
317        if let Some(outgoing_rx) = &mut self.outgoing_rx {
318            while let Ok(fd) = outgoing_rx.try_recv() {
319                // SAFETY: It we own these file descriptors, so it's safe for us to close them.
320                unsafe {
321                    libc::close(fd);
322                };
323            }
324        }
325
326        self.incoming_fds.clear_poison();
327        let mut fds = VecDeque::new();
328        if let Ok(mut guard) = self.incoming_fds.lock() {
329            std::mem::swap(&mut fds, &mut *guard);
330        }
331        // SAFETY: It we own these file descriptors, so it's safe for us to close them.
332        unsafe { close_fds(fds) };
333    }
334}
335
336impl<T: AsRawFd> AsyncRead for UnixFdStream<T> {
337    fn poll_read(
338        self: Pin<&mut Self>,
339        cx: &mut Context<'_>,
340        buf: &mut ReadBuf<'_>,
341    ) -> Poll<Result<()>> {
342        loop {
343            let mut guard = ready!(self.inner.poll_read_ready(cx))?;
344
345            let unfilled = buf.initialize_unfilled();
346            match guard
347                .try_io(|inner| Self::raw_read(self.max_read_fds, inner.as_raw_fd(), unfilled))
348            {
349                Ok(Ok((len, mut read_fds))) => {
350                    if let Ok(mut guard) = self.incoming_fds.lock() {
351                        guard.append(&mut read_fds);
352                    } else {
353                        // SAFETY: We own the file descriptors, so it's safe to close them.
354                        unsafe {
355                            close_fds(read_fds);
356                        }
357                    }
358                    buf.advance(len);
359                    return Poll::Ready(Ok(()));
360                }
361                Ok(Err(err)) => return Poll::Ready(Err(err)),
362                Err(_would_block) => continue,
363            }
364        }
365    }
366}
367
368impl<T: AsRawFd + Shutdown + Unpin> AsyncWrite for UnixFdStream<T> {
369    fn poll_write(
370        mut self: Pin<&mut Self>,
371        cx: &mut Context<'_>,
372        buf: &[u8],
373    ) -> Poll<std::result::Result<usize, std::io::Error>> {
374        let mut outgoing_fds = Vec::<RawFd>::new();
375        loop {
376            while let Ok(fd) = self.outgoing_rx.as_mut().unwrap().try_recv() {
377                outgoing_fds.push(fd);
378            }
379            let mut guard = match self.inner.poll_write_ready(cx) {
380                Poll::Ready(Ok(guard)) => guard,
381                Poll::Ready(Err(err)) => {
382                    self.add_to_outgoing(outgoing_fds);
383                    return Poll::Ready(Err(err));
384                }
385                Poll::Pending => {
386                    self.add_to_outgoing(outgoing_fds);
387                    return Poll::Pending;
388                }
389            };
390            match guard.try_io(|inner| {
391                UnixFdStream::<UnixStream>::raw_write(inner.as_raw_fd(), &outgoing_fds, buf)
392            }) {
393                Ok(Ok(bytes)) => {
394                    // SAFETY: We own the file descriptors, so it's safe to close them.
395                    unsafe {
396                        close_fds(outgoing_fds);
397                    }
398                    return Poll::Ready(Ok(bytes));
399                }
400                Ok(Err(err)) => {
401                    self.add_to_outgoing(outgoing_fds);
402                    return Poll::Ready(Err(err));
403                }
404                Err(_would_block) => continue,
405            }
406        }
407    }
408
409    fn poll_flush(
410        self: Pin<&mut Self>,
411        _cx: &mut Context<'_>,
412    ) -> Poll<std::result::Result<(), std::io::Error>> {
413        Poll::Ready(Ok(()))
414    }
415
416    fn poll_shutdown(
417        self: Pin<&mut Self>,
418        _cx: &mut Context<'_>,
419    ) -> Poll<std::result::Result<(), std::io::Error>> {
420        Poll::Ready(Shutdown::shutdown(
421            self.inner.get_ref(),
422            std::net::Shutdown::Write,
423        ))
424    }
425}
426
#[cfg(test)]
mod tests {
    use std::os::fd::FromRawFd;
    use std::os::unix::net::UnixStream as StdUnixStream;

    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
    use tokio::net::UnixStream as TokioUnixStream;

    use crate::UnixFdStream;

    /// Pass one end of a fresh socket pair across a `UnixFdStream`, then
    /// check that the receiver can read data through the transferred fd.
    #[tokio::test]
    async fn send_fd() {
        let (sending_end, receiving_end) = StdUnixStream::pair().unwrap();

        let sender = tokio::spawn(async move {
            let mut carrier = UnixFdStream::new(sending_end, 0).unwrap();
            let (kept, passed) = StdUnixStream::pair().unwrap();
            let mut kept = TokioUnixStream::from_std(kept).unwrap();
            carrier.push_outgoing_fd(passed);
            carrier.write_all(b"test\n").await.unwrap();
            carrier.shutdown().await.unwrap();
            kept.write_all(b"test\n").await.unwrap();
            kept.shutdown().await.unwrap();
            // Dropping `kept` before the receiver has finished reading makes the
            // test flaky, so keep it alive until the receiver drops its end.
            let _ = kept.readable().await;
        });

        let receiver = tokio::spawn(async move {
            let carrier = BufReader::new(UnixFdStream::new(receiving_end, 4).unwrap());
            let mut lines = carrier.lines();
            assert_eq!(Some("test"), lines.next_line().await.unwrap().as_deref());

            let stream = lines.get_ref().get_ref();
            assert_eq!(1, stream.incoming_count());
            let raw = stream.pop_incoming_fd().unwrap();
            // SAFETY: The descriptor was just received and nothing else owns it.
            let passed = unsafe { StdUnixStream::from_raw_fd(raw) };
            let passed = BufReader::new(TokioUnixStream::from_std(passed).unwrap());
            assert_eq!(
                Some("test"),
                passed.lines().next_line().await.unwrap().as_deref()
            );
        });

        let (send_result, receive_result) = tokio::join!(sender, receiver);
        send_result.unwrap();
        receive_result.unwrap();
    }
}