runa_io/
lib.rs

1#![feature(type_alias_impl_trait)]
2use std::{
3    io::Result,
4    mem::MaybeUninit,
5    os::{
6        fd::FromRawFd,
7        unix::{
8            io::{AsRawFd, OwnedFd, RawFd},
9            net::UnixStream as StdUnixStream,
10        },
11    },
12    pin::Pin,
13    rc::Rc,
14    task::{ready, Context, Poll},
15};
16
17pub mod buf;
18pub mod utils;
19
20pub use buf::*;
21use bytes::{Buf, BufMut, BytesMut};
22use pin_project_lite::pin_project;
23use runa_io_traits::OwnedFds as _;
24
25pub mod traits {
26    pub use runa_io_traits::*;
27}
28
/// Upper bound on the number of file descriptors the wayland protocol lets a
/// single write carry; the value is taken from libwayland.
pub const MAX_FDS_OUT: usize = 28;

/// Maximum number of file descriptors handled in one `SCM_RIGHTS` control
/// message.
///
/// Used to size both the receive-side control-message buffer and the
/// per-write descriptor staging area. Presumably mirrors the kernel's own
/// `SCM_MAX_FD` limit — confirm when porting to a new platform.
pub const SCM_MAX_FD: usize = 253;
34
35#[derive(Debug)]
36pub struct ReadWithFd {
37    inner: Rc<async_io::Async<OwnedFd>>,
38
39    /// Temporary buffer used for recvmsg.
40    buf: Vec<u8>,
41}
42
43#[derive(Debug)]
44pub struct WriteWithFd {
45    inner: Rc<async_io::Async<OwnedFd>>,
46}
47
48#[tracing::instrument(level = "debug", ret)]
49pub fn split_unixstream(stream: StdUnixStream) -> Result<(ReadWithFd, WriteWithFd)> {
50    let raw_fd = Rc::new(async_io::Async::new(stream.into())?);
51    Ok((
52        ReadWithFd {
53            inner: raw_fd.clone(),
54            buf:   nix::cmsg_space!([RawFd; SCM_MAX_FD]),
55        },
56        WriteWithFd { inner: raw_fd },
57    ))
58}
59
60impl traits::AsyncWriteWithFd for WriteWithFd {
61    /// Writes the given buffer and file descriptors to a unix stream. `buf`
62    /// must contain at least one byte of data. This function should not be
63    /// called concurrently from different tasks. Otherwise you risk
64    /// interleaving data, as well as causing tasks to wake each other up and
65    /// eatting CPU.
66    #[inline]
67    fn poll_write_with_fds<Fds: traits::OwnedFds>(
68        self: Pin<&mut Self>,
69        cx: &mut Context<'_>,
70        buf: &[u8],
71        fds: &mut Fds,
72    ) -> Poll<Result<usize>> {
73        use nix::sys::socket::{sendmsg, ControlMessage, MsgFlags};
74
75        ready!(self.inner.poll_writable(cx)?);
76        let fd = self.inner.as_raw_fd();
77        let mut tmp_fds = OwnedFds::<SCM_MAX_FD>::new();
78        fds.take(&mut tmp_fds);
79
80        match sendmsg::<()>(
81            fd,
82            &[std::io::IoSlice::new(buf)],
83            &[ControlMessage::ScmRights(unsafe { tmp_fds.as_raw_fds() })],
84            MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_NOSIGNAL,
85            None,
86        ) {
87            Err(nix::errno::Errno::EWOULDBLOCK) => Poll::Pending,
88            Err(e) => Poll::Ready(Err(e.into())),
89            Ok(n) => Poll::Ready(Ok(n)),
90        }
91    }
92}
93
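// The recvmsg helpers below are adapted from nix so that the caller's
// control-message buffer can be reused across calls instead of being
// allocated for every receive.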
95
96unsafe fn pack_mhdr_to_receive<'outer, 'inner, I, S>(
97    iov: I,
98    cmsg_buffer: &mut Option<&mut Vec<u8>>,
99    address: *mut S,
100) -> (usize, libc::msghdr)
101where
102    I: AsRef<[std::io::IoSliceMut<'inner>]> + 'outer,
103    S: nix::sys::socket::SockaddrLike + 'outer,
104{
105    let (msg_control, msg_controllen) = cmsg_buffer
106        .as_mut()
107        .map(|v| (v.as_mut_ptr(), v.capacity()))
108        .unwrap_or((std::ptr::null_mut(), 0));
109
110    let mhdr = {
111        // Musl's msghdr has private fields, so this is the only way to
112        // initialize it.
113        let mut mhdr = std::mem::MaybeUninit::<libc::msghdr>::zeroed();
114        let p = mhdr.as_mut_ptr();
115        (*p).msg_name = (*address).as_mut_ptr() as *mut libc::c_void;
116        (*p).msg_namelen = S::size();
117        (*p).msg_iov = iov.as_ref().as_ptr() as *mut libc::iovec;
118        (*p).msg_iovlen = iov.as_ref().len() as _;
119        (*p).msg_control = msg_control as *mut libc::c_void;
120        (*p).msg_controllen = msg_controllen as _;
121        (*p).msg_flags = 0;
122        mhdr.assume_init()
123    };
124
125    (msg_controllen, mhdr)
126}
127
128#[derive(Clone, Copy, Debug, Eq, PartialEq)]
129pub struct RecvMsg<'a, S> {
130    bytes:   usize,
131    cmsghdr: Option<&'a libc::cmsghdr>,
132    address: Option<S>,
133    flags:   nix::sys::socket::MsgFlags,
134    mhdr:    libc::msghdr,
135}
136impl<'a, S> RecvMsg<'a, S> {
137    /// Iterate over the valid control messages pointed to by this
138    /// msghdr.
139    pub fn scm_rights(&self) -> ScmRightsIterator {
140        ScmRightsIterator {
141            cmsghdr: self.cmsghdr,
142            mhdr:    &self.mhdr,
143        }
144    }
145}
146
147#[derive(Clone, Copy, Debug, Eq, PartialEq)]
148pub struct ScmRightsIterator<'a> {
149    /// Control message buffer to decode from. Must adhere to cmsg alignment.
150    cmsghdr: Option<&'a libc::cmsghdr>,
151    mhdr:    &'a libc::msghdr,
152}
153
154pub struct FdIter<'a> {
155    cmsghdr: &'a libc::cmsghdr,
156    idx:     usize,
157}
158
159impl<'a> Iterator for ScmRightsIterator<'a> {
160    type Item = FdIter<'a>;
161
162    fn next(&mut self) -> Option<Self::Item> {
163        loop {
164            match self.cmsghdr {
165                None => break None, // No more messages
166                Some(hdr) => {
167                    // Get the data.
168                    // Safe if cmsghdr points to valid data returned by recvmsg(2)
169                    let ret = FdIter {
170                        cmsghdr: hdr,
171                        idx:     0,
172                    };
173                    self.cmsghdr = unsafe {
174                        let p = libc::CMSG_NXTHDR(self.mhdr as *const _, hdr as *const _);
175                        p.as_ref()
176                    };
177                    if hdr.cmsg_type != libc::SCM_RIGHTS || hdr.cmsg_level != libc::SOL_SOCKET {
178                        continue
179                    }
180                    break Some(ret)
181                },
182            }
183        }
184    }
185}
186
187impl Iterator for FdIter<'_> {
188    type Item = RawFd;
189
190    fn next(&mut self) -> Option<Self::Item> {
191        let p = unsafe { libc::CMSG_DATA(self.cmsghdr as *const _) };
192        let data_len = self.cmsghdr as *const _ as usize + self.cmsghdr.cmsg_len - p as usize;
193        let nfds = data_len / std::mem::size_of::<RawFd>();
194        let fds = unsafe { std::slice::from_raw_parts(p as *const RawFd, nfds) };
195        let ret = fds.get(self.idx).copied();
196        self.idx += 1;
197        ret
198    }
199}
200
201unsafe fn read_mhdr<'b, S>(
202    mhdr: libc::msghdr,
203    r: isize,
204    msg_controllen: usize,
205    address: S,
206    cmsg_buffer: &mut Option<&'b mut Vec<u8>>,
207) -> RecvMsg<'b, S>
208where
209    S: nix::sys::socket::SockaddrLike,
210{
211    let cmsghdr = {
212        if mhdr.msg_controllen > 0 {
213            // got control message(s)
214            cmsg_buffer.as_mut().unwrap().set_len(mhdr.msg_controllen);
215            debug_assert!(!mhdr.msg_control.is_null());
216            debug_assert!(msg_controllen >= mhdr.msg_controllen);
217            libc::CMSG_FIRSTHDR(&mhdr as *const libc::msghdr)
218        } else {
219            std::ptr::null()
220        }
221        .as_ref()
222    };
223
224    RecvMsg {
225        bytes: r as usize,
226        cmsghdr,
227        address: Some(address),
228        flags: nix::sys::socket::MsgFlags::from_bits_truncate(mhdr.msg_flags),
229        mhdr,
230    }
231}
232
233pub fn recvmsg<'a, 'outer, 'inner, S>(
234    fd: RawFd,
235    iov: &'outer mut [std::io::IoSliceMut<'inner>],
236    mut cmsg_buffer: Option<&'a mut Vec<u8>>,
237    flags: nix::sys::socket::MsgFlags,
238) -> std::result::Result<RecvMsg<'a, S>, nix::Error>
239where
240    S: nix::sys::socket::SockaddrLike + 'a,
241{
242    let mut address = std::mem::MaybeUninit::uninit();
243
244    let (msg_controllen, mut mhdr) =
245        unsafe { pack_mhdr_to_receive::<_, S>(iov, &mut cmsg_buffer, address.as_mut_ptr()) };
246
247    let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) };
248
249    let r = nix::errno::Errno::result(ret)?;
250
251    Ok(unsafe {
252        read_mhdr(
253            mhdr,
254            r,
255            msg_controllen,
256            address.assume_init(),
257            &mut cmsg_buffer,
258        )
259    })
260}
261
262impl traits::AsyncReadWithFd for ReadWithFd {
263    /// This implementation will close extra file descriptors if fd_limit is
264    /// reached.
265    fn poll_read_with_fds<Fds: traits::OwnedFds>(
266        mut self: Pin<&mut Self>,
267        cx: &mut Context<'_>,
268        buf: &mut [u8],
269        fds: &mut Fds,
270    ) -> Poll<Result<usize>> {
271        use nix::sys::socket::MsgFlags;
272        ready!(self.inner.poll_readable(cx)?);
273        let fd = self.inner.as_raw_fd();
274
275        match recvmsg::<()>(
276            fd,
277            &mut [std::io::IoSliceMut::new(buf)],
278            Some(&mut self.buf),
279            MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_NOSIGNAL | MsgFlags::MSG_CMSG_CLOEXEC,
280        ) {
281            Err(nix::errno::Errno::EWOULDBLOCK) => Poll::Pending,
282            Err(e) => Poll::Ready(Err(e.into())),
283            Ok(msg) => {
284                let ifds = msg
285                    .scm_rights()
286                    .flatten()
287                    // Safety: we just received those file descriptors so we know
288                    // they are valid and not shared.
289                    .map(|fd| unsafe { OwnedFd::from_raw_fd(fd) });
290                fds.extend(ifds);
291                Poll::Ready(Ok(msg.bytes))
292            },
293        }
294    }
295}
296
/// Inline, fixed-capacity container of up to `N` owned file descriptors.
///
/// Slots `fds[..len]` hold initialized descriptors; the remaining slots are
/// uninitialized.
pub struct OwnedFds<const N: usize> {
    fds: [MaybeUninit<OwnedFd>; N],
    len: usize,
}

impl<const N: usize> Default for OwnedFds<N> {
    /// Creates an empty container; no slot is initialized.
    fn default() -> Self {
        const EMPTY_SLOT: MaybeUninit<OwnedFd> = MaybeUninit::uninit();
        Self {
            len: 0,
            fds: [EMPTY_SLOT; N],
        }
    }
}

impl<const N: usize> std::fmt::Debug for OwnedFds<N> {
    /// Formats the stored descriptors as a list.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // SAFETY: only the first `len` slots are initialized.
        let live = self.fds[..self.len]
            .iter()
            .map(|slot| unsafe { slot.assume_init_ref() });
        f.debug_list().entries(live).finish()
    }
}

impl<const N: usize> OwnedFds<N> {
    /// Creates an empty container; same as [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a pointer to the first descriptor slot. Only the first `len`
    /// slots are initialized.
    pub fn as_ptr(&self) -> *const OwnedFd {
        self.fds.as_ptr().cast::<OwnedFd>()
    }

    /// Borrows the stored descriptors as a slice of raw file descriptors.
    ///
    /// # Safety
    ///
    /// The caller must not close any of the returned descriptors; they are
    /// still owned by `self`.
    pub unsafe fn as_raw_fds(&self) -> &[RawFd] {
        // SAFETY: `OwnedFd` is `#[repr(transparent)]` over a raw descriptor,
        // and the first `len` slots are initialized.
        unsafe { std::slice::from_raw_parts(self.as_ptr().cast::<RawFd>(), self.len) }
    }
}
342
343impl<const N: usize> traits::OwnedFds for OwnedFds<N> {
344    #[inline]
345    fn len(&self) -> usize {
346        self.len
347    }
348
349    #[inline]
350    fn capacity(&self) -> Option<usize> {
351        Some(N)
352    }
353
354    fn take<T: Extend<OwnedFd>>(&mut self, fds: &mut T) {
355        for fd in self.fds[..self.len].iter_mut() {
356            fds.extend(Some(unsafe { fd.assume_init_read() }));
357        }
358        self.len = 0;
359    }
360}
361
362impl<const N: usize> Extend<OwnedFd> for OwnedFds<N> {
363    fn extend<T: IntoIterator<Item = OwnedFd>>(&mut self, fds: T) {
364        for fd in fds {
365            if self.len < N {
366                self.fds[self.len] = MaybeUninit::new(fd);
367                self.len += 1;
368            } else {
369                drop(fd);
370            }
371        }
372    }
373}
374
375impl<const N: usize> Drop for OwnedFds<N> {
376    fn drop(&mut self) {
377        for fd in self.fds[..self.len].iter_mut() {
378            unsafe { fd.assume_init_drop() };
379        }
380        self.len = 0;
381    }
382}
383
384pin_project! {
385#[derive(Debug)]
386pub struct Connection<C> {
387    #[pin]
388    conn:     C,
389    buf:      BytesMut,
390    fds:      OwnedFds<SCM_MAX_FD>,
391    capacity: usize,
392}
393}
394
395impl<C> Connection<C> {
396    pub fn new(conn: C, capacity: usize) -> Self {
397        Connection {
398            conn,
399            capacity,
400            fds: OwnedFds::new(),
401            buf: BytesMut::with_capacity(capacity),
402        }
403    }
404}
405
406impl<C: traits::AsyncWriteWithFd> traits::WriteMessage for Connection<C> {
407    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
408        // Flush if we are:
409        //   1. over capacity or
410        //   2. not having enough space for MAX_FDS_OUT file descriptors
411        if self.buf.len() > self.capacity || self.fds.len() + MAX_FDS_OUT > SCM_MAX_FD {
412            ready!(self.poll_flush(cx))?;
413        }
414        Poll::Ready(Ok(()))
415    }
416
417    fn start_send<M: traits::ser::Serialize + std::fmt::Debug>(
418        self: Pin<&mut Self>,
419        object_id: u32,
420        msg: M,
421    ) {
422        let this = self.project();
423        assert!(msg.nfds() as usize + this.fds.len() <= SCM_MAX_FD);
424        this.buf.put_u32_ne(object_id);
425        this.buf
426            .reserve((msg.len() as usize).saturating_sub(this.buf.remaining_mut()));
427        msg.serialize(this.buf, this.fds);
428    }
429
430    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
431        let mut this = self.project();
432        while !this.buf.is_empty() {
433            let written =
434                ready!(this
435                    .conn
436                    .as_mut()
437                    .poll_write_with_fds(cx, &*this.buf, &mut *this.fds))?;
438            if written == 0 {
439                return Poll::Ready(Err(std::io::Error::new(
440                    std::io::ErrorKind::WriteZero,
441                    "written 0 bytes",
442                )))
443            }
444            this.buf.advance(written);
445        }
446        Poll::Ready(Ok(()))
447    }
448}