withfd/
lib.rs

//! `withfd` allows passing file descriptors through Unix sockets.
//!
//! This crate provides adapters for `std::os::unix::net::UnixStream`,
//! `tokio::net::UnixStream` (requires the `tokio` feature) and
//! `async_io::Async<std::os::unix::net::UnixStream>` (requires the `async-io`
//! feature) that allow passing file descriptors through them.
//!
//! The adapter allows you to keep using the ordinary `Read` and `Write` (or
//! `AsyncRead` and `AsyncWrite` with the `tokio` or `async-io` features)
//! interfaces. File descriptors are received and stored as you read. This is
//! different from other similar crates like
//! [`passfd`](https://crates.io/crates/passfd) or
//! [`sendfd`](https://crates.io/crates/sendfd), and addresses the problem
//! where, if you use an ordinary read on the `UnixStream` when the other end
//! has sent a file descriptor, the file descriptor is dropped. This adapter
//! ensures that no file descriptors are lost.
15//!
16//! # Example
17//!
18//! Process 1:
19//!
20//! ```no_run
21//! use std::{
22//!     fs::File,
23//!     os::unix::{io::AsFd, net::UnixListener},
24//! };
25//!
26//! use withfd::WithFdExt;
27//!
28//! let file = File::open("/etc/passwd").unwrap();
29//! let listener = UnixListener::bind("/tmp/test.sock").unwrap();
30//! let (stream, _) = listener.accept().unwrap();
31//! let mut stream = stream.with_fd();
32//! stream.write_with_fd(b"data", &[file.as_fd()]).unwrap();
33//! ```
34//!
35//! Process 2:
36//!
37//! ```no_run
38//! use std::{
39//!     fs::File,
40//!     io::Read,
//!     os::unix::net::UnixStream,
42//! };
43//!
44//! use withfd::WithFdExt;
45//!
46//! let stream = UnixStream::connect("/tmp/test.sock").unwrap();
47//! let mut stream = stream.with_fd();
48//! let mut buf = [0u8; 4];
49//! stream.read_exact(&mut buf[..]).unwrap();
50//! let fd = stream.take_fds().next().unwrap();
51//! let mut file = File::from(fd);
52//! let mut buf = String::new();
53//! file.read_to_string(&mut buf).unwrap();
54//! println!("{}", buf);
55//! ```
56#![cfg_attr(docsrs, feature(doc_cfg))]
57
58use std::{
59    io::{IoSlice, IoSliceMut, Read, Write},
60    os::fd::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd},
61};
62
63use nix::sys::socket::ControlMessageOwned;
64
/// Adapter for sending and receiving data together with file descriptors.
///
/// You can create this by using the [`WithFdExt`] trait and calling the
/// `with_fd` method on supported types, or through the provided `From`
/// conversions.
///
/// File descriptors that arrive while reading are buffered inside the
/// adapter until they are retrieved with `take_fds`.
#[cfg_attr(feature = "async-io", pin_project::pin_project)]
pub struct WithFd<T> {
    /// The wrapped socket.
    #[cfg_attr(feature = "async-io", pin)]
    inner: T,
    /// Descriptors received but not yet handed out by `take_fds`.
    fds:   Vec<OwnedFd>,
    /// Reusable control-message buffer passed to `recvmsg`.
    cmsg:  Vec<u8>,
}
76
77pub trait WithFdExt: Sized {
78    fn with_fd(self) -> WithFd<Self>;
79}
80
/// Maximum number of file descriptors passed in a single message.
///
/// Mirrors the Linux kernel's `SCM_MAX_FD` limit. The adapters use it to
/// size their control-message receive buffer.
pub const SCM_MAX_FD: usize = 253;
82
83impl Read for WithFd<std::os::unix::net::UnixStream> {
84    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
85        self.read_with_fd(buf)
86    }
87}
88impl Write for WithFd<std::os::unix::net::UnixStream> {
89    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
90        self.inner.write(buf)
91    }
92
93    fn flush(&mut self) -> std::io::Result<()> {
94        self.inner.flush()
95    }
96
97    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
98        self.inner.write_all(buf)
99    }
100
101    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> std::io::Result<usize> {
102        self.inner.write_vectored(bufs)
103    }
104
105    fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> {
106        self.inner.write_fmt(fmt)
107    }
108}
109
110impl<T: AsRawFd> WithFd<T> {
111    fn write_with_fd_impl(fd: RawFd, buf: &[u8], fds: &[BorrowedFd<'_>]) -> std::io::Result<usize> {
112        // Safety: BorrowedFd is repr(transparent) over RawFd
113        let fds = unsafe { std::slice::from_raw_parts(fds.as_ptr().cast::<RawFd>(), fds.len()) };
114        let cmsg = nix::sys::socket::ControlMessage::ScmRights(fds);
115        let sendmsg = nix::sys::socket::sendmsg::<()>(
116            fd,
117            &[IoSlice::new(buf)],
118            &[cmsg],
119            nix::sys::socket::MsgFlags::empty(),
120            None,
121        )?;
122        Ok(sendmsg)
123    }
124
125    fn raw_read_with_fd(
126        fd: RawFd,
127        cmsg: &mut Vec<u8>,
128        out_fds: &mut Vec<OwnedFd>,
129        buf: &mut [u8],
130    ) -> std::io::Result<usize> {
131        let mut buf = [IoSliceMut::new(buf)];
132        let recvmsg = nix::sys::socket::recvmsg::<()>(
133            fd,
134            &mut buf,
135            Some(cmsg),
136            nix::sys::socket::MsgFlags::empty(),
137        )?;
138        for cmsg in recvmsg.cmsgs()? {
139            if let ControlMessageOwned::ScmRights(fds) = cmsg {
140                out_fds.extend(fds.iter().map(|&fd| unsafe { OwnedFd::from_raw_fd(fd) }));
141            }
142        }
143        Ok(recvmsg.bytes)
144    }
145
146    fn read_with_fd(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
147        let fd = self.inner.as_raw_fd();
148        Self::raw_read_with_fd(fd, &mut self.cmsg, &mut self.fds, buf)
149    }
150
151    /// Returns an iterator over the file descriptors received.
152    /// Every file descriptor this iterator yields will be removed from the
153    /// internal buffer, and will not be returned again. Dropping the iterator
154    /// without exhausting it will leave the remaining file descriptors intact.
155    pub fn take_fds(&mut self) -> impl Iterator<Item = OwnedFd> + '_ {
156        struct Iter<'a>(&'a mut Vec<OwnedFd>);
157        impl Iterator for Iter<'_> {
158            type Item = OwnedFd;
159
160            fn next(&mut self) -> Option<Self::Item> {
161                self.0.pop()
162            }
163        }
164        Iter(&mut self.fds)
165    }
166}
167impl WithFd<std::os::unix::net::UnixStream> {
168    /// Write data, with additional pass file descriptors. For most of the unix
169    /// systems, file descriptors must be sent along with at least one byte
170    /// of data. This is why there is not a `write_fd` method.
171    pub fn write_with_fd(&mut self, buf: &[u8], fds: &[BorrowedFd<'_>]) -> std::io::Result<usize> {
172        let fd = self.inner.as_raw_fd();
173        Self::write_with_fd_impl(fd, buf, fds)
174    }
175}
176
177impl WithFdExt for std::os::unix::net::UnixStream {
178    fn with_fd(self) -> WithFd<Self> {
179        self.into()
180    }
181}
182
183impl From<std::os::unix::net::UnixStream> for WithFd<std::os::unix::net::UnixStream> {
184    fn from(inner: std::os::unix::net::UnixStream) -> Self {
185        Self {
186            inner,
187            fds: Vec::new(),
188            cmsg: nix::cmsg_space!([RawFd; SCM_MAX_FD]),
189        }
190    }
191}
192
#[cfg(test)]
mod test {
    use std::{
        fs::File,
        io::{Read, Seek, Write},
        os::fd::AsFd,
    };

    #[cfg(target_os = "linux")]
    use cstr::cstr;
    #[cfg(target_os = "linux")]
    use nix::sys::memfd::MemFdCreateFlag;

    /// Creates an anonymous in-memory file to pass between the socket ends.
    #[cfg(target_os = "linux")]
    fn new_memfd() -> File {
        nix::sys::memfd::memfd_create(cstr!("test"), MemFdCreateFlag::MFD_CLOEXEC)
            .unwrap()
            .into()
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_send_fd() {
        let (a, b) = std::os::unix::net::UnixStream::pair().unwrap();
        let mut a = super::WithFd::from(a);
        let mut b = super::WithFd::from(b);

        let mut memfd = new_memfd();
        a.write_with_fd(b"hello", &[memfd.as_fd()]).unwrap();
        let mut buf = [0u8; 5];
        b.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"hello");
        let fds = b.take_fds().collect::<Vec<_>>();
        assert_eq!(fds.len(), 1);

        let mut memfd2: File = fds.into_iter().next().unwrap().into();

        memfd.write_all(b"Hello").unwrap();
        drop(memfd);

        memfd2.rewind().unwrap();
        memfd2.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"Hello");
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_take_fds_partial() {
        let (a, b) = std::os::unix::net::UnixStream::pair().unwrap();
        let mut a = super::WithFd::from(a);
        let mut b = super::WithFd::from(b);

        let first = new_memfd();
        let second = new_memfd();
        a.write_with_fd(b"x", &[first.as_fd(), second.as_fd()]).unwrap();
        let mut buf = [0u8; 1];
        b.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"x");

        // Taking one descriptor and dropping the iterator must keep the other.
        assert!(b.take_fds().next().is_some());
        assert_eq!(b.take_fds().count(), 1);
        assert_eq!(b.take_fds().count(), 0);
    }

    #[cfg(all(feature = "async-io", target_os = "linux"))]
    #[tokio::test]
    async fn test_send_fd_async_async_io() {
        use futures_util::io::{AsyncReadExt, AsyncWriteExt};
        let (a, b) = async_io::Async::<std::os::unix::net::UnixStream>::pair().unwrap();
        let a = super::WithFd::from(a);
        let mut b = super::WithFd::from(b);

        let mut memfd = new_memfd();
        tokio::spawn(async move {
            memfd.write_all(b"Hello").unwrap();
            a.write_with_fd(b"hello", &[memfd.as_fd()]).await.unwrap();
            (&a).write_all(b"world").await.unwrap();
            drop(memfd);
        });
        let mut buf = [0u8; 5];
        b.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf[..], b"hello");
        let fds = b.take_fds().collect::<Vec<_>>();
        assert_eq!(fds.len(), 1);
        b.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf[..], b"world");

        let mut memfd2: File = fds.into_iter().next().unwrap().into();

        memfd2.rewind().unwrap();
        memfd2.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"Hello");
    }

    #[cfg(all(feature = "tokio", target_os = "linux"))]
    #[tokio::test]
    async fn test_send_fd_async_tokio() {
        use tokio::io::AsyncReadExt;
        let (a, b) = tokio::net::UnixStream::pair().unwrap();
        let mut a = super::WithFd::from(a);
        let mut b = super::WithFd::from(b);

        // `memfd_create` already returns an owned descriptor; no raw-fd juggling.
        let mut memfd = new_memfd();
        a.write_with_fd(b"hello", &[memfd.as_fd()]).await.unwrap();
        let mut buf = [0u8; 5];
        b.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf[..], b"hello");
        let read_handle = tokio::spawn(async move {
            // Test that background read works
            b.read_exact(&mut buf).await.unwrap();
            (b, buf)
        });

        // Yield so the read has a chance to run
        tokio::task::yield_now().await;

        a.write_with_fd(b"world", &[]).await.unwrap();
        let (mut b, mut buf) = read_handle.await.unwrap();
        assert_eq!(&buf[..], b"world");
        let fds = b.take_fds().collect::<Vec<_>>();
        assert_eq!(fds.len(), 1);

        let mut memfd2: File = fds.into_iter().next().unwrap().into();

        memfd.write_all(b"Hello").unwrap();
        drop(memfd);

        memfd2.rewind().unwrap();
        memfd2.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"Hello");
    }
}
305
#[cfg(any(feature = "tokio", docsrs))]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
#[doc(hidden)]
pub mod tokio {
    //! [`WithFd`] support for `tokio::net::UnixStream`.

    use std::{
        os::fd::{AsRawFd, BorrowedFd, RawFd},
        pin::Pin,
        task::{ready, Context, Poll},
    };

    use tokio::io::{AsyncRead, AsyncWrite, Interest, ReadBuf};

    use crate::WithFd;

    impl AsyncRead for WithFd<tokio::net::UnixStream> {
        /// Reads data with `recvmsg`, buffering any received file descriptors
        /// so they can be retrieved with `take_fds`.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let unfilled = buf.initialize_unfilled();
            let Self { inner, cmsg, fds } = self.get_mut();
            let fd = inner.as_raw_fd();
            loop {
                // Returns `Pending` (with the waker registered) until readable.
                ready!(inner.poll_read_ready(cx))?;
                // `try_io` clears the cached readiness when the read reports
                // `WouldBlock`; looping lets `poll_read_ready` register the waker.
                match inner.try_io(Interest::READABLE, || {
                    Self::raw_read_with_fd(fd, cmsg, fds, unfilled)
                }) {
                    Ok(bytes) => {
                        buf.advance(bytes);
                        return Poll::Ready(Ok(()));
                    },
                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
                    Err(e) => return Poll::Ready(Err(e)),
                }
            }
        }
    }

    // Plain writes carry no ancillary data; forward them to the inner stream.
    impl AsyncWrite for WithFd<tokio::net::UnixStream> {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, std::io::Error>> {
            Pin::new(&mut self.inner).poll_write(cx, buf)
        }

        fn poll_flush(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), std::io::Error>> {
            Pin::new(&mut self.inner).poll_flush(cx)
        }

        fn poll_shutdown(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), std::io::Error>> {
            Pin::new(&mut self.inner).poll_shutdown(cx)
        }

        fn poll_write_vectored(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[std::io::IoSlice<'_>],
        ) -> Poll<Result<usize, std::io::Error>> {
            Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }
    }

    impl WithFd<tokio::net::UnixStream> {
        /// Write data, with additional pass file descriptors, returning the
        /// number of data bytes written. For most of the unix systems, file
        /// descriptors must be sent along with at least one byte of data.
        /// This is why there is not a `write_fd` method.
        pub async fn write_with_fd(
            &mut self,
            buf: &[u8],
            fds: &[BorrowedFd<'_>],
        ) -> std::io::Result<usize> {
            let fd = self.inner.as_raw_fd();
            loop {
                self.inner.writable().await?;
                // On `WouldBlock`, `try_io` clears readiness so the next
                // `writable()` waits for a fresh event instead of spinning.
                match self
                    .inner
                    .try_io(Interest::WRITABLE, || Self::write_with_fd_impl(fd, buf, fds))
                {
                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
                    result => return result,
                }
            }
        }
    }

    impl From<tokio::net::UnixStream> for WithFd<tokio::net::UnixStream> {
        fn from(inner: tokio::net::UnixStream) -> Self {
            Self {
                inner,
                fds: Vec::new(),
                cmsg: nix::cmsg_space!([RawFd; super::SCM_MAX_FD]),
            }
        }
    }

    impl super::WithFdExt for tokio::net::UnixStream {
        fn with_fd(self) -> super::WithFd<Self> {
            self.into()
        }
    }
}
422
#[cfg(any(feature = "async-io", docsrs))]
#[cfg_attr(docsrs, doc(cfg(feature = "async-io")))]
#[doc(hidden)]
pub mod async_io {
    //! [`WithFd`] support for `async_io::Async<std::os::unix::net::UnixStream>`.

    use std::{
        os::fd::{AsRawFd, BorrowedFd},
        pin::Pin,
        task::{ready, Context, Poll},
    };

    use async_io::Async;
    use futures_io::{AsyncRead, AsyncWrite};

    use crate::WithFd;

    impl AsyncRead for WithFd<Async<std::os::unix::net::UnixStream>> {
        /// Reads data with `recvmsg`, buffering any received file descriptors
        /// so they can be retrieved with `take_fds`.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<futures_io::Result<usize>> {
            let this = self.project();
            let fd = this.inner.as_raw_fd();
            loop {
                // Attempt the read first; only wait for readiness on `WouldBlock`.
                match Self::raw_read_with_fd(fd, this.cmsg, this.fds, buf) {
                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {},
                    result => return Poll::Ready(result),
                }
                ready!(this.inner.poll_readable(cx))?;
            }
        }
    }

    // Writes through a shared reference; plain writes carry no ancillary data.
    impl<T> AsyncWrite for &WithFd<Async<T>>
    where
        for<'a> &'a Async<T>: AsyncWrite,
    {
        fn poll_close(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<futures_io::Result<()>> {
            Pin::new(&mut &self.inner).poll_close(cx)
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<futures_io::Result<()>> {
            Pin::new(&mut &self.inner).poll_flush(cx)
        }

        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<futures_io::Result<usize>> {
            Pin::new(&mut &self.inner).poll_write(cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[futures_io::IoSlice<'_>],
        ) -> Poll<futures_io::Result<usize>> {
            Pin::new(&mut &self.inner).poll_write_vectored(cx, bufs)
        }
    }

    // Writes through an exclusive pinned reference, projecting to the inner socket.
    impl<T> AsyncWrite for WithFd<Async<T>>
    where
        Async<T>: AsyncWrite,
    {
        fn poll_close(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<futures_io::Result<()>> {
            self.project().inner.poll_close(cx)
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<futures_io::Result<()>> {
            self.project().inner.poll_flush(cx)
        }

        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<futures_io::Result<usize>> {
            self.project().inner.poll_write(cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[futures_io::IoSlice<'_>],
        ) -> Poll<futures_io::Result<usize>> {
            self.project().inner.poll_write_vectored(cx, bufs)
        }
    }

    impl WithFd<Async<std::os::unix::net::UnixStream>> {
        /// Write data, with additional pass file descriptors, returning the
        /// number of data bytes written. For most of the unix systems, file
        /// descriptors must be sent along with at least one byte of data.
        /// This is why there is not a `write_fd` method.
        pub async fn write_with_fd(
            &self,
            buf: &[u8],
            fds: &[BorrowedFd<'_>],
        ) -> std::io::Result<usize> {
            let fd = self.inner.as_raw_fd();
            loop {
                self.inner.writable().await?;
                match Self::write_with_fd_impl(fd, buf, fds) {
                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
                    result => return result,
                }
            }
        }
    }

    impl From<Async<std::os::unix::net::UnixStream>> for WithFd<Async<std::os::unix::net::UnixStream>> {
        fn from(inner: Async<std::os::unix::net::UnixStream>) -> Self {
            Self {
                inner,
                fds: Vec::new(),
                cmsg: nix::cmsg_space!([std::os::unix::io::RawFd; super::SCM_MAX_FD]),
            }
        }
    }

    impl super::WithFdExt for Async<std::os::unix::net::UnixStream> {
        fn with_fd(self) -> super::WithFd<Self> {
            self.into()
        }
    }
}