makepad_platform/os/linux/
ipc.rs

//! IPC ("inter-process communication") abstractions used on Linux.
//!
//! **NOTE**: the actual implementations may be portable to other OSes,
//! e.g. "UNIX domain sockets" are definitely not Linux-only, but there
//! may be other reasons to only *need* them on Linux (as opposed to
//! other OSes, such as macOS).
6
7use std::{
8    io,
9    marker::PhantomData,
10    os::{
11        fd::{AsFd, BorrowedFd, OwnedFd},
12        unix::net::UnixStream,
13    },
14};
15
/// A single end of a two-way, order-preserving inter-process channel.
///
/// Values of type `TX` are encoded and sent, values of type `RX` are received
/// and decoded, and each message may carry both raw bytes and UNIX file
/// descriptors. Messages arrive in the same order in which they were sent.
//
// FIXME(eddyb) should this be moved to a `mod channel` and renamed to e.g.
// `SenderReceiver`? (and mimicking `std::sync::mpsc` for `Sender`/`Receiver`)
pub struct Channel<TX, RX> {
    // Only `fn` pointers mention `TX`/`RX`, which makes the type invariant in
    // both while leaving auto traits (`Send`/`Sync`) unaffected by them.
    _marker: PhantomData<(fn(TX) -> RX, fn(RX) -> TX)>,
    /// The UNIX domain socket backing this endpoint.
    stream: UnixStream,
}
27
28pub fn channel<TX, RX>() -> io::Result<(Channel<TX, RX>, Channel<RX, TX>)> {
29    let (a, b) = UnixStream::pair()?;
30    Ok((
31        Channel {
32            stream: a,
33            _marker: PhantomData,
34        },
35        Channel {
36            stream: b,
37            _marker: PhantomData,
38        },
39    ))
40}
41
42impl<TX, RX> Clone for Channel<TX, RX> {
43    fn clone(&self) -> Self {
44        Self {
45            stream: self.stream.try_clone().unwrap(),
46            _marker: PhantomData,
47        }
48    }
49}
50
// FIXME(eddyb) the `cfg(use_unstable_unix_socket_ancillary_data_2021)`
// implementation works on (and has been tested for) nightlies ranging
// from early 2021 to late 2023 (roughly matching 1.51 - 1.73 releases),
// but is provided here mostly for pedagogical reasons, as it's quite
// likely stabilization (in 2024 or later) will be blocked on a redesign
// of the API, as per https://github.com/rust-lang/rust/issues/76915
// comments (also, note that this cfg has no exposed way of turning
// it on, short of passing it to `rustc` via `RUSTFLAGS=--cfg=...`).
#[cfg(use_unstable_unix_socket_ancillary_data_2021)]
mod sys {
    use super::*;
    use std::os::fd::FromRawFd;
    use std::os::unix::net::{AncillaryData, SocketAncillary};

    /// Send all of `bytes` together with `fds` (as `SCM_RIGHTS` ancillary
    /// data), in a single `sendmsg` call.
    pub(super) fn stream_sendmsg<const FD_LEN: usize>(
        stream: &UnixStream,
        bytes: io::IoSlice<'_>,
        fds: &[BorrowedFd<'_>; FD_LEN],
    ) -> io::Result<()> {
        let mut ancillary_buffer = [0; 64];
        let mut ancillary = SocketAncillary::new(&mut ancillary_buffer);
        // SAFETY: `BorrowedFd` is `#[repr(transparent)]` over `RawFd` (`i32`).
        let raw_fds = unsafe { &*(fds as *const [BorrowedFd<'_>] as *const [i32]) };
        if !ancillary.add_fds(raw_fds) {
            return Err(io::Error::other(format!(
                "failed to send {FD_LEN} file descriptors: \
                 the resulting cmsg doesn't fit in {} bytes",
                ancillary.capacity()
            )));
        }
        let written_len = stream.send_vectored_with_ancillary(&[bytes], &mut ancillary)?;
        if written_len != bytes.len() {
            return Err(io::Error::other(format!(
                "partial write (only {written_len} out of {})",
                bytes.len()
            )));
        }
        Ok(())
    }

    /// Receive exactly `bytes.len()` bytes together with exactly `FD_LEN`
    /// fds (from a single `SCM_RIGHTS` cmsg), in a single `recvmsg` call.
    pub(super) fn stream_recvmsg<const FD_LEN: usize>(
        stream: &UnixStream,
        bytes: io::IoSliceMut<'_>,
    ) -> io::Result<[OwnedFd; FD_LEN]> {
        let mut ancillary_buffer = [0; 64];
        let mut ancillary = SocketAncillary::new(&mut ancillary_buffer);
        let expected_len = bytes.len();
        let read_len = stream.recv_vectored_with_ancillary(&mut [bytes], &mut ancillary)?;
        let partial_read = read_len != expected_len;
        let (ancillary_truncated, ancillary_capacity) =
            (ancillary.truncated(), ancillary.capacity());

        // HACK(eddyb) this is painfully stateful so that it has a chance to
        // `close` *all* unwanted `OwnedFd`s, to avoid keeping *any* alive
        // (even without a malicious sender, any mistake could easily end up
        // leaking hundreds of file descriptors, and with e.g. DMA-BUF they'd
        // easily keep alive buffers totalling more than most GPUs have VRAM).
        let mut errors = vec![];
        let mut accepted_fds = [(); FD_LEN].map(|()| None);
        let mut accepted_fd_count = 0;
        // Tracked explicitly (instead of via `accepted_fd_count == 0`), so that
        // an earlier `SCM_RIGHTS` cmsg without any usable fds still counts.
        let mut seen_scm_rights = false;
        for cmsg in ancillary.messages() {
            match cmsg {
                Err(err) => errors.push(format!("{err:?}")),
                Ok(AncillaryData::ScmRights(raw_fds)) => {
                    let is_first_scm_rights = !seen_scm_rights;
                    seen_scm_rights = true;
                    for raw_fd in raw_fds {
                        if raw_fd == -1 {
                            errors.push("invalid fd (-1) received".into());
                            continue;
                        }
                        // Using `OwnedFd` ensures all unwanted file descriptors
                        // are closed (see larger comment above for why).
                        let fd = unsafe { OwnedFd::from_raw_fd(raw_fd) };
                        if is_first_scm_rights {
                            // NOTE(eddyb) too few/many fds are handled later.
                            let i = accepted_fd_count;
                            accepted_fd_count += 1;
                            if let Some(slot) = accepted_fds.get_mut(i) {
                                *slot = Some(fd);
                            }
                        }
                    }
                    if !is_first_scm_rights {
                        errors.push("received more than one SCM_RIGHTS cmsg".into());
                    }
                }
                Ok(AncillaryData::ScmCredentials(_)) => {
                    errors.push("received unexpected SCM_CREDS-like cmsg".into());
                }
            }
        }
        if accepted_fd_count != FD_LEN {
            errors.push(format!(
                "wrong number of received fds: expected {FD_LEN}, got {accepted_fd_count}"
            ))
        }

        if partial_read {
            return Err(io::Error::other(format!(
                "partial read: only {read_len} out of {expected_len}"
            )));
        }
        if ancillary_truncated {
            return Err(io::Error::other(format!(
                "truncated ancillary buffer: received cmsg doesn't fit in {ancillary_capacity} bytes"
            )));
        }

        if errors.is_empty() {
            Ok(accepted_fds.map(Option::unwrap))
        } else {
            Err(io::Error::other(if errors.len() == 1 {
                errors.pop().unwrap()
            } else {
                format!("errors during receiving:\n  {}", errors.join("\n  "))
            }))
        }
    }
}
#[cfg(not(use_unstable_unix_socket_ancillary_data_2021))]
mod sys {
    #![allow(non_camel_case_types)]

    // HACK(eddyb) `io::Error::other` stabilization is too recent.
    fn io_error_other(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
        io::Error::new(io::ErrorKind::Other, error)
    }

    use super::*;
    use std::{
        convert::TryFrom,
        ffi::{c_int, c_void},
        mem::{self, MaybeUninit},
        os::fd::FromRawFd,
        ptr,
    };

    type socklen_t = u32;

    // Hand-written mirrors of the Linux C definitions from `<sys/socket.h>`.
    // `msghdr` is generic over the `iovec`-compatible element type
    // (`io::IoSlice`/`io::IoSliceMut`, both ABI-compatible with `struct iovec`).
    #[repr(C)]
    struct msghdr<IOV> {
        msg_name: *mut c_void,
        msg_namelen: socklen_t,
        msg_iov: *mut IOV,
        msg_iovlen: usize,
        msg_control: *mut c_void,
        msg_controllen: usize,
        msg_flags: c_int,
    }

    const SOL_SOCKET: c_int = 1;
    const SCM_RIGHTS: c_int = 1;

    /// `recvmsg` flag: atomically mark every received fd close-on-exec, so
    /// that they don't get inherited by child processes by accident.
    const MSG_CMSG_CLOEXEC: c_int = 0x4000_0000;
    /// `msg_flags` bit set by `recvmsg` when some control data (e.g. fds) did
    /// not fit in the provided buffer (the kernel closes the excess fds).
    const MSG_CTRUNC: c_int = 0x8;

    #[repr(C)]
    struct cmsghdr {
        cmsg_len: usize,
        cmsg_level: c_int,
        cmsg_type: c_int,
    }
    const _: () = assert!(std::mem::size_of::<cmsghdr>() % std::mem::size_of::<usize>() == 0);

    /// Equivalent of C `CMSG_LEN(fd_count * sizeof(int))`: no extra alignment
    /// is needed, as `size_of::<cmsghdr>()` is a multiple of `usize` (see above).
    const fn cmsg_len_for_fds(fd_count: usize) -> usize {
        mem::size_of::<cmsghdr>() + fd_count * mem::size_of::<c_int>()
    }

    extern "C" {
        fn sendmsg(
            sockfd: BorrowedFd<'_>,
            msg: *const msghdr<io::IoSlice<'_>>,
            flags: c_int,
        ) -> isize;
        fn recvmsg(
            sockfd: BorrowedFd<'_>,
            msg: *mut msghdr<io::IoSliceMut<'_>>,
            flags: c_int,
        ) -> isize;
    }

    /// A single `SCM_RIGHTS` control message: being `#[repr(C)]`, its size is
    /// padded up to its alignment, matching C's `CMSG_SPACE(FD_LEN * sizeof(int))`.
    #[repr(C)]
    struct CMsgBuf<FD, const FD_LEN: usize> {
        header: cmsghdr,
        fds: [FD; FD_LEN],
    }

    /// Send all of `bytes`, together with `fds` (as a single `SCM_RIGHTS`
    /// control message), using exactly one `sendmsg` call.
    ///
    /// A partial write is reported as an error.
    pub(super) fn stream_sendmsg<const FD_LEN: usize>(
        stream: &UnixStream,
        mut bytes: io::IoSlice<'_>,
        fds: &[BorrowedFd<'_>; FD_LEN],
    ) -> io::Result<()> {
        let mut cmsg_buf = CMsgBuf {
            header: cmsghdr {
                cmsg_len: cmsg_len_for_fds(FD_LEN),
                cmsg_level: SOL_SOCKET,
                cmsg_type: SCM_RIGHTS,
            },
            fds: *fds,
        };
        let msg_controllen = mem::size_of_val(&cmsg_buf);

        // SAFETY: all pointers in `msghdr` point to live locals that outlive
        // the call, and `msg_iovlen`/`msg_controllen` match their sizes.
        let written_len = unsafe {
            sendmsg(
                stream.as_fd(),
                &msghdr {
                    msg_name: ptr::null_mut(),
                    msg_namelen: 0,
                    msg_iov: &mut bytes,
                    msg_iovlen: 1,
                    msg_control: ptr::addr_of_mut!(cmsg_buf).cast(),
                    msg_controllen,
                    msg_flags: 0,
                },
                0,
            )
        };
        if written_len == -1 {
            return Err(io::Error::last_os_error());
        }
        if written_len as usize != bytes.len() {
            return Err(io_error_other(format!(
                "partial write (only {written_len} out of {})",
                bytes.len()
            )));
        }
        Ok(())
    }

    /// Receive exactly `bytes.len()` bytes, together with exactly `FD_LEN` fds
    /// (as a single `SCM_RIGHTS` control message), using one `recvmsg` call.
    ///
    /// Received fds are close-on-exec. If validation fails, every fd received
    /// alongside the message is closed instead of leaked.
    pub(super) fn stream_recvmsg<const FD_LEN: usize>(
        stream: &UnixStream,
        mut bytes: io::IoSliceMut<'_>,
    ) -> io::Result<[OwnedFd; FD_LEN]> {
        let expected_len = bytes.len();

        // NOTE: the fd slots are raw `c_int`s (and not e.g. `Option<OwnedFd>`),
        // because an all-zero `Option<OwnedFd>` is `Some(fd 0)`, which would be
        // (wrongly) closed on drop if the kernel didn't overwrite that slot.
        let mut cmsg_buf = MaybeUninit::<CMsgBuf<c_int, FD_LEN>>::zeroed();
        let header_size = mem::size_of::<cmsghdr>();
        let fd_size = mem::size_of::<c_int>();
        let expected_cmsg_len = cmsg_len_for_fds(FD_LEN);
        let expected_msg_controllen = mem::size_of_val(&cmsg_buf);

        let mut msg = msghdr {
            msg_name: ptr::null_mut(),
            msg_namelen: 0,
            msg_iov: &mut bytes,
            msg_iovlen: 1,
            msg_control: cmsg_buf.as_mut_ptr().cast(),
            msg_controllen: expected_msg_controllen,
            msg_flags: 0,
        };

        // SAFETY: all pointers in `msg` point to live locals that outlive the
        // call, and `msg_iovlen`/`msg_controllen` match their sizes.
        let read_len = unsafe { recvmsg(stream.as_fd(), &mut msg, MSG_CMSG_CLOEXEC) };
        if read_len == -1 {
            return Err(io::Error::last_os_error());
        }
        let received_controllen = msg.msg_controllen;
        let received_flags = msg.msg_flags;

        // SAFETY: `CMsgBuf<c_int, _>` only contains integers, and all of its
        // bytes were zero-initialized (then possibly overwritten by the kernel).
        let (cmsg_len, cmsg_level, cmsg_type) = {
            let header = unsafe { &cmsg_buf.assume_init_ref().header };
            (header.cmsg_len, header.cmsg_level, header.cmsg_type)
        };
        let is_scm_rights = (cmsg_level, cmsg_type) == (SOL_SOCKET, SCM_RIGHTS);

        // Take ownership of every fd the kernel installed into this process
        // *before* validating anything, so that all error paths below close
        // them (when `received_fds` is dropped), instead of leaking them.
        let received_fds: Vec<Option<OwnedFd>> =
            if is_scm_rights && received_controllen >= header_size {
                // The kernel may also use `CMsgBuf`'s trailing padding for one
                // more fd, so the slot count is based on the whole buffer.
                let fd_slots = (expected_msg_controllen - header_size) / fd_size;
                let fd_count = (cmsg_len.min(received_controllen).saturating_sub(header_size)
                    / fd_size)
                    .min(fd_slots);
                let fds_base = cmsg_buf
                    .as_ptr()
                    .cast::<u8>()
                    .wrapping_add(header_size)
                    .cast::<c_int>();
                (0..fd_count)
                    .map(|i| {
                        // SAFETY: `i < fd_slots`, so the read stays within the
                        // (fully initialized) `cmsg_buf`.
                        let raw_fd = unsafe { fds_base.add(i).read_unaligned() };
                        // SAFETY: unless `-1`, the kernel just installed this fd
                        // into our fd table, so nothing else owns it yet.
                        (raw_fd != -1).then(|| unsafe { OwnedFd::from_raw_fd(raw_fd) })
                    })
                    .collect()
            } else {
                Vec::new()
            };

        if read_len as usize != expected_len {
            return Err(io_error_other(format!(
                "partial read: only {read_len} out of {expected_len}"
            )));
        }

        if received_flags & MSG_CTRUNC != 0 {
            return Err(io_error_other(format!(
                "recvmsg control data truncated (MSG_CTRUNC): \
                 received cmsg doesn't fit in {expected_msg_controllen} bytes"
            )));
        }

        if received_controllen != expected_msg_controllen {
            return Err(io_error_other(format!(
                "recvmsg msg_controllen mismatch: got {received_controllen}, \
                 expected {expected_msg_controllen}"
            )));
        }

        if cmsg_len != expected_cmsg_len {
            return Err(io_error_other(format!(
                "recvmsg cmsg_len mismatch: got {cmsg_len}, expected {expected_cmsg_len}"
            )));
        }

        if !is_scm_rights {
            return Err(io_error_other("unsupported non-SCM_RIGHTS CMSG"));
        }

        if received_fds.iter().any(Option::is_none) {
            return Err(io_error_other("recvmsg got invalid (-1) fds"));
        }

        let fds: Vec<OwnedFd> = received_fds.into_iter().flatten().collect();
        <[OwnedFd; FD_LEN]>::try_from(fds).map_err(|fds| {
            io_error_other(format!(
                "wrong number of received fds: expected {FD_LEN}, got {}",
                fds.len()
            ))
        })
    }
}
327
328impl<TX, RX> Channel<TX, RX> {
329    pub fn send<const TX_BYTE_LEN: usize, const TX_FD_LEN: usize>(&self, msg: TX) -> io::Result<()>
330    where
331        TX: FixedSizeEncoding<TX_BYTE_LEN, TX_FD_LEN>,
332    {
333        assert_ne!(
334            TX_FD_LEN,
335            0,
336            "Channel<{}, _> unsupported (lacks file descriptors)",
337            std::any::type_name::<TX>()
338        );
339
340        let (bytes, fds) = msg.encode();
341        sys::stream_sendmsg(&self.stream, io::IoSlice::new(&bytes), &fds)
342    }
343
344    pub fn recv<const RX_BYTE_LEN: usize, const RX_FD_LEN: usize>(&self) -> io::Result<RX>
345    where
346        RX: FixedSizeEncoding<RX_BYTE_LEN, RX_FD_LEN>,
347    {
348        assert_ne!(
349            RX_FD_LEN,
350            0,
351            "Channel<_, {}> unsupported (lacks file descriptors)",
352            std::any::type_name::<TX>()
353        );
354
355        // FIXME(eddyb) this should use `io::BorrowedBuf` when that's stabilized.
356        let mut bytes = [0; RX_BYTE_LEN];
357        let fds = sys::stream_recvmsg(&self.stream, io::IoSliceMut::new(&mut bytes))?;
358        Ok(RX::decode(bytes, fds))
359    }
360
361    /// Enable child process inheritance (see [`InheritableChannel`]),
362    /// i.e. remove the `CLOEXEC` flag (via `dup`, not `fcntl(F_{SET,GET}FD)`,
363    /// due to the latter's misdesign as read/write instead of `fetch_{and,or}`,
364    /// so they invite race conditions and should be deprecated and never used).
365    pub fn into_child_process_inheritable(self) -> io::Result<InheritableChannel<TX, RX>> {
366        extern "C" {
367            fn dup(fd: BorrowedFd<'_>) -> Option<OwnedFd>;
368        }
369        Ok(InheritableChannel(Self {
370            stream: unsafe { dup(self.stream.as_fd()) }
371                .ok_or_else(|| io::Error::last_os_error())?
372                .into(),
373            _marker: PhantomData,
374        }))
375    }
376}
377
378/// A `Channel<TX, RX>` whose internal (UNIX domain socket) file descriptor will
379/// persist in all child proceses (except for those which explicitly close it),
380/// and which only provides conversions to/from file descriptors, and a way to
381/// disable inheritance (i.e. re-enabling `CLOEXEC` semantics on it).
382pub struct InheritableChannel<TX, RX>(Channel<TX, RX>);
383
384impl<TX, RX> AsFd for InheritableChannel<TX, RX> {
385    fn as_fd(&self) -> BorrowedFd<'_> {
386        self.0.stream.as_fd()
387    }
388}
389
390impl<TX, RX> From<OwnedFd> for InheritableChannel<TX, RX> {
391    fn from(fd: OwnedFd) -> Self {
392        Self(Channel {
393            stream: UnixStream::from(fd),
394            _marker: PhantomData,
395        })
396    }
397}
398
399impl<TX, RX> InheritableChannel<TX, RX> {
400    /// Disable child process inheritance, i.e. re-add the `CLOEXEC` flag
401    /// (via `try_clone_to_owned` which uses `fcntl(F_DUPFD_CLOEXEC)`).
402    pub fn into_uninheritable(self) -> io::Result<Channel<TX, RX>> {
403        let Self(mut channel) = self;
404        channel.stream = channel.stream.as_fd().try_clone_to_owned()?.into();
405        Ok(channel)
406    }
407}
408
/// Uninhabited type (no value of it can ever exist).
///
/// Use it as the `TX` or `RX` of a channel endpoint to make a direction
/// unusable when it isn't needed: nothing can be sent on that endpoint (or,
/// equivalently, received on its opposite counterpart).
pub enum Never {}
412
/// Fixed-size message encoding: every value encodes to exactly `BYTE_LEN`
/// bytes and `FD_LEN` file descriptors.
///
/// Because each message's size is known up front, it can be sent with exactly
/// one `sendmsg` and received with exactly one `recvmsg`. This removes the
/// need for any "packet framing", which a `SOCK_STREAM` would otherwise need
/// to soundly handle a message's fds arriving across several `recvmsg` calls.
//
// HACK(eddyb) using const generics instead of associated consts
// only to be able to use the compile-time constants in array types.
pub trait FixedSizeEncoding<const BYTE_LEN: usize, const FD_LEN: usize> {
    // HACK(eddyb) avoids repeating the value inside an `impl`.
    const BYTE_LEN: usize = BYTE_LEN;
    const FD_LEN: usize = FD_LEN;

    /// Inverse of `encode`: rebuild a value from its bytes and (now owned) fds.
    fn decode(bytes: [u8; BYTE_LEN], fds: [OwnedFd; FD_LEN]) -> Self;

    /// Produce the bytes and (borrowed) fds representing `self`.
    fn encode(&self) -> ([u8; BYTE_LEN], [BorrowedFd<'_>; FD_LEN]);
}
429
430// HACK(eddyb) simple `(OnlyBytes, OnlyFds)` to make it easier for const generics.
431impl<
432        const BYTE_LEN: usize,
433        const FD_LEN: usize,
434        A: FixedSizeEncoding<BYTE_LEN, 0>,
435        B: FixedSizeEncoding<0, FD_LEN>,
436    > FixedSizeEncoding<BYTE_LEN, FD_LEN> for (A, B)
437{
438    fn encode(&self) -> ([u8; BYTE_LEN], [BorrowedFd<'_>; FD_LEN]) {
439        let ((bytes, []), ([], fds)) = (self.0.encode(), self.1.encode());
440        (bytes, fds)
441    }
442    fn decode(bytes: [u8; BYTE_LEN], fds: [OwnedFd; FD_LEN]) -> Self {
443        (A::decode(bytes, []), B::decode([], fds))
444    }
445}
446
447macro_rules! fixed_size_le_prim_impls {
448    ($($ty:ident)*) => {
449        $(impl FixedSizeEncoding<{(Self::BITS / 8) as usize}, 0> for $ty {
450            fn encode(&self) -> ([u8; Self::BYTE_LEN], [BorrowedFd<'_>; 0]) {
451                (self.to_le_bytes(), [])
452            }
453            fn decode(bytes: [u8; Self::BYTE_LEN], []: [OwnedFd; 0]) -> Self {
454                Self::from_le_bytes(bytes)
455            }
456        })*
457    }
458}
459fixed_size_le_prim_impls!(u16 u32 u64 u128);
460
461impl FixedSizeEncoding<0, 1> for OwnedFd {
462    fn encode(&self) -> ([u8; 0], [BorrowedFd<'_>; 1]) {
463        ([], [self.as_fd()])
464    }
465    fn decode([]: [u8; 0], [fd]: [OwnedFd; 1]) -> Self {
466        fd
467    }
468}