unix_ancillary/ancillary.rs
1use std::marker::PhantomData;
2use std::os::unix::io::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
3use std::{fmt, mem};
4
/// Error returned when a control message cannot be placed in the ancillary
/// buffer (typically because the buffer is too small for it).
///
/// It carries no data, so it is `Copy` and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AncillaryError;

impl fmt::Display for AncillaryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("ancillary buffer too small")
    }
}

impl std::error::Error for AncillaryError {}
16
17/// Received ancillary data from a Unix socket.
18pub enum AncillaryData<'a> {
19 /// File descriptors received via `SCM_RIGHTS`.
20 ScmRights(ScmRights<'a>),
21}
22
/// Iterator over file descriptors received via `SCM_RIGHTS`.
///
/// Each yielded `OwnedFd` takes ownership of one received descriptor and
/// closes it on drop. Entries holding a negative value are skipped, and
/// trailing bytes too short to form a whole `RawFd` are ignored.
///
/// # Important
///
/// Iterate this exactly once. Iterating the same `Messages`/`ScmRights` view
/// twice would manufacture two `OwnedFd`s for the same raw fd, leading to a
/// double-close. Conversely, descriptors that are never yielded (because the
/// iterator is dropped early) are never wrapped in an `OwnedFd`, so nothing
/// in this type closes them.
#[must_use = "iterators are lazy; descriptors that are never yielded are never closed"]
pub struct ScmRights<'a> {
    /// Raw `SCM_RIGHTS` payload: consecutive native-endian `RawFd` values.
    data: &'a [u8],
    /// Byte offset of the next unread entry in `data`.
    offset: usize,
}

impl<'a> ScmRights<'a> {
    /// Wraps an `SCM_RIGHTS` payload, positioned at its first entry.
    pub(crate) fn new(data: &'a [u8]) -> Self {
        ScmRights { data, offset: 0 }
    }
}

impl Iterator for ScmRights<'_> {
    type Item = OwnedFd;

    fn next(&mut self) -> Option<Self::Item> {
        const FD_SIZE: usize = mem::size_of::<RawFd>();

        // `get` returns None once fewer than FD_SIZE bytes remain, which also
        // drops any trailing partial entry.
        while let Some(bytes) = self.data.get(self.offset..self.offset + FD_SIZE) {
            self.offset += FD_SIZE;
            let mut fd_bytes = [0u8; FD_SIZE];
            fd_bytes.copy_from_slice(bytes);
            let raw = RawFd::from_ne_bytes(fd_bytes);

            // The kernel never delivers negative fd values via SCM_RIGHTS;
            // any negative is malformed input. Skip silently rather than
            // tripping `OwnedFd::from_raw_fd`'s precondition (which panics
            // under debug assertions and is UB to violate).
            if raw < 0 {
                continue;
            }

            // SAFETY: the kernel just delivered this fd to us via recvmsg
            // SCM_RIGHTS; we wrap it in OwnedFd immediately and the caller
            // owns it from this point. Caller MUST iterate exactly once —
            // see type docs.
            return Some(unsafe { OwnedFd::from_raw_fd(raw) });
        }
        None
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Every remaining whole slot may yield one fd, but negative entries
        // are skipped, so none is guaranteed.
        let slots = self.data.len().saturating_sub(self.offset) / mem::size_of::<RawFd>();
        (0, Some(slots))
    }
}

// Once `next` returns None the offset stops advancing, so it keeps
// returning None.
impl std::iter::FusedIterator for ScmRights<'_> {}
74
75/// Iterator over control messages in an ancillary buffer.
76pub struct Messages<'a> {
77 current: *const libc::cmsghdr,
78 msg: libc::msghdr,
79 _marker: PhantomData<&'a [u8]>,
80}
81
82impl<'a> Messages<'a> {
83 fn new(buffer: &'a [u8], length: usize) -> Self {
84 // SAFETY: zeroed msghdr followed by explicit field init.
85 let mut msg: libc::msghdr = unsafe { mem::zeroed() };
86 msg.msg_control = buffer.as_ptr() as *mut libc::c_void;
87 msg.msg_controllen = length as _;
88
89 // SAFETY: msg.msg_control points at `buffer` for `length` bytes;
90 // CMSG_FIRSTHDR walks that region per the cmsg(3) contract.
91 let current = unsafe { libc::CMSG_FIRSTHDR(&msg) };
92
93 Messages {
94 current,
95 msg,
96 _marker: PhantomData,
97 }
98 }
99}
100
101impl<'a> Iterator for Messages<'a> {
102 type Item = AncillaryData<'a>;
103
104 fn next(&mut self) -> Option<Self::Item> {
105 // Loop instead of recursing on unknown cmsg types: an adversarial
106 // peer could otherwise force unbounded recursion.
107 loop {
108 if self.current.is_null() {
109 return None;
110 }
111
112 // Compute buffer bounds once. Used to validate cmsg_len before
113 // calling CMSG_NXTHDR, which performs unchecked pointer
114 // arithmetic on (corrupted) cmsg_len in libc and would otherwise
115 // overflow on malformed input. Buffers reaching us from
116 // `recvmsg` are kernel-formatted, but defending against bogus
117 // input is cheap and protects fuzz/test/replay use cases.
118 let buf_start = self.msg.msg_control as usize;
119 #[allow(clippy::unnecessary_cast)]
120 let buf_end = buf_start.saturating_add(self.msg.msg_controllen as usize);
121 let cur_addr = self.current as usize;
122
123 // SAFETY: `current` is non-null and points inside the borrowed
124 // buffer (guaranteed by CMSG_FIRSTHDR/CMSG_NXTHDR contract);
125 // reading the header is sound.
126 #[allow(clippy::unnecessary_cast)]
127 // cmsg_len is size_t on Linux but socklen_t (u32) elsewhere
128 let (level, ty, data_ptr, data_len, well_formed) = unsafe {
129 let cmsg = &*self.current;
130 let data_ptr = libc::CMSG_DATA(self.current as *mut _);
131 let header_len = (data_ptr as usize).saturating_sub(cur_addr);
132 let total = cmsg.cmsg_len as usize;
133
134 // Bound `total` to the bytes remaining in the buffer from
135 // this cmsg's start. Anything claiming to extend past the
136 // buffer is malformed; we treat its data area as empty and
137 // refuse to walk further.
138 let remaining = buf_end.saturating_sub(cur_addr);
139 let well_formed = total >= header_len && total <= remaining;
140 let data_len = if well_formed { total - header_len } else { 0 };
141
142 (
143 cmsg.cmsg_level,
144 cmsg.cmsg_type,
145 data_ptr,
146 data_len,
147 well_formed,
148 )
149 };
150
151 // Advance only if the current cmsg is well-formed: libc's
152 // CMSG_NXTHDR reads cmsg_len from the cmsghdr directly and would
153 // overflow pointer arithmetic on a corrupted value. If
154 // malformed, terminate the walk after handling the current
155 // entry's data slice.
156 self.current = if well_formed {
157 // SAFETY: cmsg_len fits in the buffer; CMSG_NXTHDR will
158 // either return a valid in-buffer pointer or null.
159 unsafe { libc::CMSG_NXTHDR(&self.msg, self.current) }
160 } else {
161 std::ptr::null()
162 };
163
164 if level == libc::SOL_SOCKET && ty == libc::SCM_RIGHTS {
165 // SAFETY: data_ptr is in-buffer; data_len is bounded by
166 // the buffer end via the well-formed check above. Lifetime
167 // ties to the buffer borrowed by Messages<'a>.
168 let data: &'a [u8] = unsafe { std::slice::from_raw_parts(data_ptr, data_len) };
169 return Some(AncillaryData::ScmRights(ScmRights::new(data)));
170 }
171 // Unknown cmsg type — skip and continue walking. If we marked
172 // `current` null above (malformed), the next loop iteration
173 // returns None.
174 }
175 }
176}
177
/// Buffer for building and parsing Unix socket ancillary data (control
/// messages).
///
/// Used with `sendmsg`/`recvmsg` to pass file descriptors via `SCM_RIGHTS`.
/// Only the first `length` bytes of `buffer` hold control messages; the rest
/// is spare capacity.
#[derive(Debug)]
pub struct SocketAncillary<'a> {
    /// Caller-provided backing storage.
    pub(crate) buffer: &'a mut [u8],
    /// Number of leading bytes of `buffer` currently holding control messages.
    pub(crate) length: usize,
    /// Whether the ancillary data was reported truncated during a receive.
    pub(crate) truncated: bool,
}
187
188impl<'a> SocketAncillary<'a> {
189 /// Create a new `SocketAncillary` backed by the given buffer.
190 pub fn new(buffer: &'a mut [u8]) -> Self {
191 SocketAncillary {
192 buffer,
193 length: 0,
194 truncated: false,
195 }
196 }
197
198 /// Minimum buffer size needed to send `num_fds` file descriptors.
199 pub fn buffer_size_for_rights(num_fds: usize) -> usize {
200 // SAFETY: CMSG_SPACE is a pure inline calculation.
201 unsafe { libc::CMSG_SPACE((num_fds * mem::size_of::<RawFd>()) as libc::c_uint) as usize }
202 }
203
204 /// Append file descriptors as an `SCM_RIGHTS` cmsg.
205 ///
206 /// `BorrowedFd` ensures the caller retains ownership of the fds.
207 pub fn add_fds(&mut self, fds: &[BorrowedFd<'_>]) -> Result<(), AncillaryError> {
208 let fd_bytes_len = fds.len() * mem::size_of::<RawFd>();
209 // SAFETY: pure inline calculation.
210 let space = unsafe { libc::CMSG_SPACE(fd_bytes_len as libc::c_uint) as usize };
211
212 let new_len = self.length.checked_add(space).ok_or(AncillaryError)?;
213 if new_len > self.buffer.len() {
214 return Err(AncillaryError);
215 }
216
217 // SAFETY: we walk the buffer with cmsg(3) macros and write a single
218 // cmsghdr + fd payload at the correct offset. The buffer is
219 // exclusively borrowed and large enough for `new_len` bytes.
220 unsafe {
221 let mut msg: libc::msghdr = mem::zeroed();
222 msg.msg_control = self.buffer.as_mut_ptr() as *mut libc::c_void;
223 msg.msg_controllen = new_len as _;
224
225 let cmsg = if self.length == 0 {
226 libc::CMSG_FIRSTHDR(&msg)
227 } else {
228 let mut walk_msg: libc::msghdr = mem::zeroed();
229 walk_msg.msg_control = self.buffer.as_mut_ptr() as *mut libc::c_void;
230 walk_msg.msg_controllen = self.length as _;
231
232 let mut cur = libc::CMSG_FIRSTHDR(&walk_msg);
233 while !cur.is_null() {
234 let next = libc::CMSG_NXTHDR(&walk_msg, cur);
235 if next.is_null() {
236 break;
237 }
238 cur = next;
239 }
240 if cur.is_null() {
241 libc::CMSG_FIRSTHDR(&msg)
242 } else {
243 libc::CMSG_NXTHDR(&msg, cur)
244 }
245 };
246
247 if cmsg.is_null() {
248 return Err(AncillaryError);
249 }
250
251 (*cmsg).cmsg_level = libc::SOL_SOCKET;
252 (*cmsg).cmsg_type = libc::SCM_RIGHTS;
253 (*cmsg).cmsg_len = libc::CMSG_LEN(fd_bytes_len as libc::c_uint) as _;
254
255 // Write fds straight into the cmsg data area. `write_unaligned`
256 // because `CMSG_DATA` is not guaranteed to be `RawFd`-aligned.
257 let data_ptr = libc::CMSG_DATA(cmsg) as *mut RawFd;
258 for (i, fd) in fds.iter().enumerate() {
259 std::ptr::write_unaligned(data_ptr.add(i), fd.as_raw_fd());
260 }
261 }
262
263 self.length = new_len;
264 Ok(())
265 }
266
267 /// Iterate received ancillary data messages.
268 ///
269 /// Iterate exactly once; see [`ScmRights`].
270 pub fn messages(&self) -> Messages<'_> {
271 Messages::new(&self.buffer[..self.length], self.length)
272 }
273
274 /// Returns `true` if the ancillary data was truncated during receive.
275 ///
276 /// On platforms with `MSG_CMSG_CLOEXEC` (Linux/*BSD), truncation means
277 /// extra fds were discarded by the kernel and never entered our process.
278 /// On macOS, the kernel may have deposited fds beyond the buffer that
279 /// this crate cannot reach — **always size the buffer for the maximum
280 /// expected fd count on macOS**.
281 #[must_use]
282 pub fn is_truncated(&self) -> bool {
283 self.truncated
284 }
285
286 /// Clear the ancillary buffer for reuse.
287 pub fn clear(&mut self) {
288 self.length = 0;
289 self.truncated = false;
290 }
291}
292
293/// Internal entry point for fuzz harnesses. Walks an arbitrary byte buffer
294/// as if it were a kernel-formatted ancillary buffer.
295///
296/// **Not a stable API.** Hidden from rustdoc and not covered by the crate's
297/// semver guarantees.
298///
299/// # Safety
300///
301/// The iterator returned will produce `OwnedFd` values for any non-negative
302/// integer it finds in the SCM_RIGHTS data area. If those integers are not
303/// fds the caller exclusively owns, dropping the resulting `OwnedFd`s will
304/// close arbitrary descriptors in the process. Callers must either own
305/// every fd value present in `buf`, or wrap each yielded `OwnedFd` in
306/// `ManuallyDrop` before letting it drop.
307#[doc(hidden)]
308pub unsafe fn __fuzz_parse(buf: &[u8]) -> Messages<'_> {
309 Messages::new(buf, buf.len())
310}