Skip to main content

unix_ancillary/
ancillary.rs

1use std::marker::PhantomData;
2use std::os::unix::io::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
3use std::{fmt, mem};
4
/// Error returned by `SocketAncillary::add_fds` when a control message
/// cannot be placed in the ancillary buffer (typically because the buffer
/// is too small).
///
/// This is a plain unit value: it is cheap to copy and can be compared
/// directly, e.g. `assert_eq!(result, Err(AncillaryError))`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AncillaryError;

impl fmt::Display for AncillaryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("ancillary buffer too small")
    }
}

impl std::error::Error for AncillaryError {}
16
17/// Received ancillary data from a Unix socket.
18pub enum AncillaryData<'a> {
19    /// File descriptors received via `SCM_RIGHTS`.
20    ScmRights(ScmRights<'a>),
21}
22
/// Iterator over file descriptors received via `SCM_RIGHTS`.
///
/// Each yielded `OwnedFd` takes ownership of one received descriptor and
/// closes it on drop. Negative values in the payload are skipped, and
/// trailing bytes too short to hold a whole `RawFd` are ignored.
///
/// # Important
///
/// The raw fd numbers live in the ancillary buffer, not in this iterator.
/// Creating a second `ScmRights` over the same bytes and iterating both
/// would manufacture two `OwnedFd`s for the same raw fd, leading to a
/// double-close. An example is walking the same `Messages` source twice.
/// Iterate the received data exactly once.
///
/// Descriptors that are never yielded are *not* closed by this type, for
/// example because iteration stops early or the iterator is dropped.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct ScmRights<'a> {
    /// Raw `SCM_RIGHTS` payload: packed native-endian `RawFd` values.
    data: &'a [u8],
    /// Byte offset in `data` of the next fd to decode; only moves forward.
    offset: usize,
}

impl<'a> ScmRights<'a> {
    /// Wraps an `SCM_RIGHTS` payload slice, starting at its first fd.
    pub(crate) fn new(data: &'a [u8]) -> Self {
        ScmRights { data, offset: 0 }
    }
}

impl Iterator for ScmRights<'_> {
    type Item = OwnedFd;

    fn next(&mut self) -> Option<Self::Item> {
        const FD_SIZE: usize = mem::size_of::<RawFd>();
        loop {
            // No whole fd left: stop. `offset` is not advanced here, so
            // every later call also returns `None` (the iterator is fused).
            let end = self.offset.checked_add(FD_SIZE)?;
            let chunk = self.data.get(self.offset..end)?;
            self.offset = end;

            let mut fd_bytes = [0u8; FD_SIZE];
            fd_bytes.copy_from_slice(chunk);
            let raw = RawFd::from_ne_bytes(fd_bytes);

            // The kernel never delivers negative fd values via SCM_RIGHTS;
            // any negative is malformed input. Skip silently rather than
            // tripping `OwnedFd::from_raw_fd`'s precondition (which panics
            // under debug assertions and is UB to violate).
            if raw < 0 {
                continue;
            }

            // SAFETY: the kernel delivered this fd to us via recvmsg
            // SCM_RIGHTS; we wrap it in OwnedFd immediately and the caller
            // owns it from this point. The payload MUST be iterated exactly
            // once — see type docs.
            return Some(unsafe { OwnedFd::from_raw_fd(raw) });
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Negative entries are skipped, so nothing is guaranteed; at most
        // one fd per remaining whole `RawFd`-sized chunk.
        let remaining = self.data.len().saturating_sub(self.offset) / mem::size_of::<RawFd>();
        (0, Some(remaining))
    }
}

// `offset` never moves backwards, so once `next` returns `None` it keeps
// doing so.
impl std::iter::FusedIterator for ScmRights<'_> {}
74
75/// Iterator over control messages in an ancillary buffer.
76pub struct Messages<'a> {
77    current: *const libc::cmsghdr,
78    msg: libc::msghdr,
79    _marker: PhantomData<&'a [u8]>,
80}
81
82impl<'a> Messages<'a> {
83    fn new(buffer: &'a [u8], length: usize) -> Self {
84        // SAFETY: zeroed msghdr followed by explicit field init.
85        let mut msg: libc::msghdr = unsafe { mem::zeroed() };
86        msg.msg_control = buffer.as_ptr() as *mut libc::c_void;
87        msg.msg_controllen = length as _;
88
89        // SAFETY: msg.msg_control points at `buffer` for `length` bytes;
90        // CMSG_FIRSTHDR walks that region per the cmsg(3) contract.
91        let current = unsafe { libc::CMSG_FIRSTHDR(&msg) };
92
93        Messages {
94            current,
95            msg,
96            _marker: PhantomData,
97        }
98    }
99}
100
/// Walks the cmsg headers of the borrowed buffer, yielding one
/// `AncillaryData` per `SOL_SOCKET`/`SCM_RIGHTS` message.
///
/// Other cmsg types are skipped. A header whose `cmsg_len` is shorter than
/// its own header, or that runs past the end of the buffer, is treated as
/// malformed: its data slice is empty and the walk stops after it.
impl<'a> Iterator for Messages<'a> {
    type Item = AncillaryData<'a>;

    fn next(&mut self) -> Option<Self::Item> {
        // Loop instead of recursing on unknown cmsg types: an adversarial
        // peer could otherwise force unbounded recursion.
        loop {
            // Null means the walk finished (end of buffer or malformed header).
            if self.current.is_null() {
                return None;
            }

            // Compute buffer bounds once. Used to validate cmsg_len before
            // calling CMSG_NXTHDR, which performs unchecked pointer
            // arithmetic on (corrupted) cmsg_len in libc and would otherwise
            // overflow on malformed input. Buffers reaching us from
            // `recvmsg` are kernel-formatted, but defending against bogus
            // input is cheap and protects fuzz/test/replay use cases.
            let buf_start = self.msg.msg_control as usize;
            #[allow(clippy::unnecessary_cast)]
            let buf_end = buf_start.saturating_add(self.msg.msg_controllen as usize);
            let cur_addr = self.current as usize;

            // SAFETY: `current` is non-null and points inside the borrowed
            // buffer (guaranteed by CMSG_FIRSTHDR/CMSG_NXTHDR contract);
            // reading the header is sound.
            // NOTE(review): `&*self.current` also requires `current` to be
            // aligned for `cmsghdr`. CMSG_FIRSTHDR typically returns
            // `msg_control` itself, so a caller-supplied `&[u8]` that is not
            // suitably aligned would make this reference misaligned. One
            // example is input arriving via `__fuzz_parse`. Confirm buffer
            // alignment, or read the header with `ptr::read_unaligned`.
            #[allow(clippy::unnecessary_cast)]
            // cmsg_len is size_t on Linux but socklen_t (u32) elsewhere
            let (level, ty, data_ptr, data_len, well_formed) = unsafe {
                let cmsg = &*self.current;
                let data_ptr = libc::CMSG_DATA(self.current as *mut _);
                // Distance from the header start to its data area.
                let header_len = (data_ptr as usize).saturating_sub(cur_addr);
                let total = cmsg.cmsg_len as usize;

                // Bound `total` to the bytes remaining in the buffer from
                // this cmsg's start. Anything claiming to extend past the
                // buffer is malformed; we treat its data area as empty and
                // refuse to walk further.
                let remaining = buf_end.saturating_sub(cur_addr);
                let well_formed = total >= header_len && total <= remaining;
                let data_len = if well_formed { total - header_len } else { 0 };

                (
                    cmsg.cmsg_level,
                    cmsg.cmsg_type,
                    data_ptr,
                    data_len,
                    well_formed,
                )
            };

            // Advance only if the current cmsg is well-formed: libc's
            // CMSG_NXTHDR reads cmsg_len from the cmsghdr directly and would
            // overflow pointer arithmetic on a corrupted value. If
            // malformed, terminate the walk after handling the current
            // entry's data slice.
            // NOTE(review): some libc CMSG_NXTHDR implementations also read
            // the *following* header's cmsg_len and add it to a pointer. The
            // Linux one in the `libc` crate is believed to do this. That
            // following value is not validated here, so a corrupted
            // follow-up header may still reach that arithmetic. Verify
            // against the libc version in use.
            self.current = if well_formed {
                // SAFETY: cmsg_len fits in the buffer; CMSG_NXTHDR will
                // either return a valid in-buffer pointer or null.
                unsafe { libc::CMSG_NXTHDR(&self.msg, self.current) }
            } else {
                std::ptr::null()
            };

            if level == libc::SOL_SOCKET && ty == libc::SCM_RIGHTS {
                // SAFETY: data_ptr is in-buffer; data_len is bounded by
                // the buffer end via the well-formed check above. Lifetime
                // ties to the buffer borrowed by Messages<'a>.
                let data: &'a [u8] = unsafe { std::slice::from_raw_parts(data_ptr, data_len) };
                return Some(AncillaryData::ScmRights(ScmRights::new(data)));
            }
            // Unknown cmsg type — skip and continue walking. If we marked
            // `current` null above (malformed), the next loop iteration
            // returns None.
        }
    }
}
177
/// Buffer for building and parsing Unix socket ancillary data (control
/// messages).
///
/// Used with `sendmsg`/`recvmsg` to pass file descriptors via `SCM_RIGHTS`.
pub struct SocketAncillary<'a> {
    /// Caller-provided backing storage for cmsg headers and payloads.
    pub(crate) buffer: &'a mut [u8],
    /// Number of leading bytes of `buffer` that hold control messages.
    pub(crate) length: usize,
    /// Set when the control data was truncated during receive.
    pub(crate) truncated: bool,
}

// Manual impl: a derived one would dump the entire backing buffer, which
// may be large and may hold stale bytes beyond `length`. The sizes and the
// flag are what matter when debugging.
impl fmt::Debug for SocketAncillary<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SocketAncillary")
            .field("capacity", &self.buffer.len())
            .field("length", &self.length)
            .field("truncated", &self.truncated)
            .finish()
    }
}
187
188impl<'a> SocketAncillary<'a> {
189    /// Create a new `SocketAncillary` backed by the given buffer.
190    pub fn new(buffer: &'a mut [u8]) -> Self {
191        SocketAncillary {
192            buffer,
193            length: 0,
194            truncated: false,
195        }
196    }
197
198    /// Minimum buffer size needed to send `num_fds` file descriptors.
199    pub fn buffer_size_for_rights(num_fds: usize) -> usize {
200        // SAFETY: CMSG_SPACE is a pure inline calculation.
201        unsafe { libc::CMSG_SPACE((num_fds * mem::size_of::<RawFd>()) as libc::c_uint) as usize }
202    }
203
204    /// Append file descriptors as an `SCM_RIGHTS` cmsg.
205    ///
206    /// `BorrowedFd` ensures the caller retains ownership of the fds.
207    pub fn add_fds(&mut self, fds: &[BorrowedFd<'_>]) -> Result<(), AncillaryError> {
208        let fd_bytes_len = fds.len() * mem::size_of::<RawFd>();
209        // SAFETY: pure inline calculation.
210        let space = unsafe { libc::CMSG_SPACE(fd_bytes_len as libc::c_uint) as usize };
211
212        let new_len = self.length.checked_add(space).ok_or(AncillaryError)?;
213        if new_len > self.buffer.len() {
214            return Err(AncillaryError);
215        }
216
217        // SAFETY: we walk the buffer with cmsg(3) macros and write a single
218        // cmsghdr + fd payload at the correct offset. The buffer is
219        // exclusively borrowed and large enough for `new_len` bytes.
220        unsafe {
221            let mut msg: libc::msghdr = mem::zeroed();
222            msg.msg_control = self.buffer.as_mut_ptr() as *mut libc::c_void;
223            msg.msg_controllen = new_len as _;
224
225            let cmsg = if self.length == 0 {
226                libc::CMSG_FIRSTHDR(&msg)
227            } else {
228                let mut walk_msg: libc::msghdr = mem::zeroed();
229                walk_msg.msg_control = self.buffer.as_mut_ptr() as *mut libc::c_void;
230                walk_msg.msg_controllen = self.length as _;
231
232                let mut cur = libc::CMSG_FIRSTHDR(&walk_msg);
233                while !cur.is_null() {
234                    let next = libc::CMSG_NXTHDR(&walk_msg, cur);
235                    if next.is_null() {
236                        break;
237                    }
238                    cur = next;
239                }
240                if cur.is_null() {
241                    libc::CMSG_FIRSTHDR(&msg)
242                } else {
243                    libc::CMSG_NXTHDR(&msg, cur)
244                }
245            };
246
247            if cmsg.is_null() {
248                return Err(AncillaryError);
249            }
250
251            (*cmsg).cmsg_level = libc::SOL_SOCKET;
252            (*cmsg).cmsg_type = libc::SCM_RIGHTS;
253            (*cmsg).cmsg_len = libc::CMSG_LEN(fd_bytes_len as libc::c_uint) as _;
254
255            // Write fds straight into the cmsg data area. `write_unaligned`
256            // because `CMSG_DATA` is not guaranteed to be `RawFd`-aligned.
257            let data_ptr = libc::CMSG_DATA(cmsg) as *mut RawFd;
258            for (i, fd) in fds.iter().enumerate() {
259                std::ptr::write_unaligned(data_ptr.add(i), fd.as_raw_fd());
260            }
261        }
262
263        self.length = new_len;
264        Ok(())
265    }
266
267    /// Iterate received ancillary data messages.
268    ///
269    /// Iterate exactly once; see [`ScmRights`].
270    pub fn messages(&self) -> Messages<'_> {
271        Messages::new(&self.buffer[..self.length], self.length)
272    }
273
274    /// Returns `true` if the ancillary data was truncated during receive.
275    ///
276    /// On platforms with `MSG_CMSG_CLOEXEC` (Linux/*BSD), truncation means
277    /// extra fds were discarded by the kernel and never entered our process.
278    /// On macOS, the kernel may have deposited fds beyond the buffer that
279    /// this crate cannot reach — **always size the buffer for the maximum
280    /// expected fd count on macOS**.
281    #[must_use]
282    pub fn is_truncated(&self) -> bool {
283        self.truncated
284    }
285
286    /// Clear the ancillary buffer for reuse.
287    pub fn clear(&mut self) {
288        self.length = 0;
289        self.truncated = false;
290    }
291}
292
293/// Internal entry point for fuzz harnesses. Walks an arbitrary byte buffer
294/// as if it were a kernel-formatted ancillary buffer.
295///
296/// **Not a stable API.** Hidden from rustdoc and not covered by the crate's
297/// semver guarantees.
298///
299/// # Safety
300///
301/// The iterator returned will produce `OwnedFd` values for any non-negative
302/// integer it finds in the SCM_RIGHTS data area. If those integers are not
303/// fds the caller exclusively owns, dropping the resulting `OwnedFd`s will
304/// close arbitrary descriptors in the process. Callers must either own
305/// every fd value present in `buf`, or wrap each yielded `OwnedFd` in
306/// `ManuallyDrop` before letting it drop.
307#[doc(hidden)]
308pub unsafe fn __fuzz_parse(buf: &[u8]) -> Messages<'_> {
309    Messages::new(buf, buf.len())
310}