1use core::convert::TryFrom;
6use core::marker::PhantomData;
7use core::mem::{size_of, zeroed};
8use core::ptr::read_unaligned;
9use core::slice::from_raw_parts;
10use std::os::unix::io::RawFd;
11
12#[cfg(any(target_os = "android", target_os = "linux",))]
13use libc::{gid_t, pid_t, uid_t};
14
15
/// Integer type of `cmsghdr::cmsg_len` and `msghdr::msg_controllen` on Android and glibc Linux.
///
/// Values written to those fields in this file are cast with `as CmsgLen`, and `cmsg_len` is
/// subtracted from a `CmsgLen`, so this alias must match the libc field type exactly.
#[cfg(any(target_os = "android", all(target_os = "linux", target_env = "gnu")))]
pub(crate) type CmsgLen = usize;
18
/// Integer type of `cmsghdr::cmsg_len` and `msghdr::msg_controllen` on the BSDs, emscripten and
/// musl Linux, where those fields are `socklen_t`.
///
/// NOTE(review): targets matching neither this cfg nor the Android/glibc one get no `CmsgLen`
/// at all — confirm such targets never compile this module.
#[cfg(any(
    target_os = "dragonfly",
    target_os = "emscripten",
    target_os = "freebsd",
    all(target_os = "linux", target_env = "musl",),
    target_os = "netbsd",
    target_os = "openbsd",
))]
pub(crate) type CmsgLen = libc::socklen_t;
28
29fn add_to_ancillary_data<T>(
30 buffer: &mut [u8],
31 length: &mut usize,
32 source: &[T],
33 cmsg_level: libc::c_int,
34 cmsg_type: libc::c_int,
35) -> bool {
36 let source_len = if let Some(source_len) = source.len().checked_mul(size_of::<T>()) {
37 if let Ok(source_len) = u32::try_from(source_len) {
38 source_len
39 } else {
40 return false;
41 }
42 } else {
43 return false;
44 };
45
46 unsafe {
47 let additional_space = libc::CMSG_SPACE(source_len) as usize;
48
49 let new_length = if let Some(new_length) = additional_space.checked_add(*length) {
50 new_length
51 } else {
52 return false;
53 };
54
55 if new_length > buffer.len() {
56 return false;
57 }
58
59 for byte in &mut buffer[*length..new_length] {
60 *byte = 0;
61 }
62
63 *length = new_length;
64
65 let mut msg: libc::msghdr = zeroed();
66 msg.msg_control = buffer.as_mut_ptr().cast();
67 msg.msg_controllen = *length as CmsgLen;
68
69 let mut cmsg = libc::CMSG_FIRSTHDR(&msg);
70 let mut previous_cmsg = cmsg;
71 while !cmsg.is_null() {
72 previous_cmsg = cmsg;
73 cmsg = libc::CMSG_NXTHDR(&msg, cmsg);
74 }
75
76 if previous_cmsg.is_null() {
77 return false;
78 }
79
80 (*previous_cmsg).cmsg_level = cmsg_level;
81 (*previous_cmsg).cmsg_type = cmsg_type;
82 (*previous_cmsg).cmsg_len = libc::CMSG_LEN(source_len) as CmsgLen;
83
84 let data = libc::CMSG_DATA(previous_cmsg).cast();
85
86 libc::memcpy(data, source.as_ptr().cast(), source_len as usize);
87 }
88 true
89}
90
/// Iterator over consecutive, possibly unaligned `T` values packed into a byte slice.
///
/// Trailing bytes too few to form a whole `T` are never yielded.
struct AncillaryDataIter<'a, T> {
    data: &'a [u8],
    phantom: PhantomData<T>,
}

impl<'a, T> AncillaryDataIter<'a, T> {
    /// Creates an iterator that reads `T` values from the start of `data`.
    ///
    /// # Safety
    ///
    /// Every `size_of::<T>()`-byte chunk of `data` must be a valid bit pattern for `T`, because
    /// `next` reinterprets those bytes as a `T` without any validation.
    unsafe fn new(data: &'a [u8]) -> AncillaryDataIter<'a, T> {
        Self { data, phantom: PhantomData }
    }
}

impl<'a, T> Iterator for AncillaryDataIter<'a, T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        let unit = size_of::<T>();
        if self.data.len() < unit {
            return None;
        }

        let (head, tail) = self.data.split_at(unit);
        self.data = tail;
        // SAFETY: `head` is exactly `size_of::<T>()` bytes long, `read_unaligned` imposes no
        // alignment requirement, and the contract of `new` makes those bytes a valid `T`.
        Some(unsafe { read_unaligned(head.as_ptr().cast::<T>()) })
    }
}
122
123#[cfg(any(doc, target_os = "android", target_os = "linux",))]
125#[derive(Clone)]
126pub struct SocketCred(libc::ucred);
127
128#[cfg(any(doc, target_os = "android", target_os = "linux",))]
129impl SocketCred {
130 pub fn new() -> SocketCred {
134 SocketCred(libc::ucred { pid: 0, uid: 0, gid: 0 })
135 }
136
137 pub fn set_pid(&mut self, pid: pid_t) {
139 self.0.pid = pid;
140 }
141
142 pub fn get_pid(&self) -> pid_t {
144 self.0.pid
145 }
146
147 pub fn set_uid(&mut self, uid: uid_t) {
149 self.0.uid = uid;
150 }
151
152 pub fn get_uid(&self) -> uid_t {
154 self.0.uid
155 }
156
157 pub fn set_gid(&mut self, gid: gid_t) {
159 self.0.gid = gid;
160 }
161
162 pub fn get_gid(&self) -> gid_t {
164 self.0.gid
165 }
166}
167
168pub struct ScmRights<'a>(AncillaryDataIter<'a, RawFd>);
172
173impl<'a> Iterator for ScmRights<'a> {
174 type Item = RawFd;
175
176 fn next(&mut self) -> Option<RawFd> {
177 self.0.next()
178 }
179}
180
181#[cfg(any(doc, target_os = "android", target_os = "linux",))]
185pub struct ScmCredentials<'a>(AncillaryDataIter<'a, libc::ucred>);
186
187#[cfg(any(doc, target_os = "android", target_os = "linux",))]
188impl<'a> Iterator for ScmCredentials<'a> {
189 type Item = SocketCred;
190
191 fn next(&mut self) -> Option<SocketCred> {
192 Some(SocketCred(self.0.next()?))
193 }
194}
195
/// Error produced when a control message cannot be decoded into an `AncillaryData` variant.
#[non_exhaustive]
#[derive(Debug)]
pub enum AncillaryError {
    /// The message's `cmsg_level` is not `SOL_SOCKET`, or its `cmsg_type` is not one this
    /// module decodes. Both raw values are preserved for the caller.
    Unknown { cmsg_level: i32, cmsg_type: i32 },
}
202
203pub enum AncillaryData<'a> {
205 ScmRights(ScmRights<'a>),
206 #[cfg(any(doc, target_os = "android", target_os = "linux",))]
207 ScmCredentials(ScmCredentials<'a>),
208}
209
210impl<'a> AncillaryData<'a> {
211 unsafe fn as_rights(data: &'a [u8]) -> Self {
218 let ancillary_data_iter = AncillaryDataIter::new(data);
219 let scm_rights = ScmRights(ancillary_data_iter);
220 AncillaryData::ScmRights(scm_rights)
221 }
222
223 #[cfg(any(doc, target_os = "android", target_os = "linux",))]
230 unsafe fn as_credentials(data: &'a [u8]) -> Self {
231 let ancillary_data_iter = AncillaryDataIter::new(data);
232 let scm_credentials = ScmCredentials(ancillary_data_iter);
233 AncillaryData::ScmCredentials(scm_credentials)
234 }
235
236 fn try_from_cmsghdr(cmsg: &'a libc::cmsghdr) -> Result<Self, AncillaryError> {
237 unsafe {
238 let cmsg_len_zero = libc::CMSG_LEN(0) as CmsgLen;
239 let data_len = (*cmsg).cmsg_len - cmsg_len_zero;
240 let data = libc::CMSG_DATA(cmsg).cast();
241 let data = from_raw_parts(data, data_len as usize);
242
243 match (*cmsg).cmsg_level {
244 libc::SOL_SOCKET => match (*cmsg).cmsg_type {
245 libc::SCM_RIGHTS => Ok(AncillaryData::as_rights(data)),
246 #[cfg(any(target_os = "android", target_os = "linux",))]
247 libc::SCM_CREDENTIALS => Ok(AncillaryData::as_credentials(data)),
248 cmsg_type => {
249 Err(AncillaryError::Unknown { cmsg_level: libc::SOL_SOCKET, cmsg_type })
250 }
251 },
252 cmsg_level => {
253 Err(AncillaryError::Unknown { cmsg_level, cmsg_type: (*cmsg).cmsg_type })
254 }
255 }
256 }
257 }
258}
259
260pub struct Messages<'a> {
262 buffer: &'a [u8],
263 current: Option<&'a libc::cmsghdr>,
264}
265
266impl<'a> Iterator for Messages<'a> {
267 type Item = Result<AncillaryData<'a>, AncillaryError>;
268
269 fn next(&mut self) -> Option<Self::Item> {
270 unsafe {
271 let mut msg: libc::msghdr = zeroed();
272 msg.msg_control = self.buffer.as_ptr() as *mut _;
273 msg.msg_controllen = self.buffer.len() as CmsgLen;
274
275 let cmsg = if let Some(current) = self.current {
276 libc::CMSG_NXTHDR(&msg, current)
277 } else {
278 libc::CMSG_FIRSTHDR(&msg)
279 };
280
281 let cmsg = cmsg.as_ref()?;
282 self.current = Some(cmsg);
283 let ancillary_result = AncillaryData::try_from_cmsghdr(cmsg);
284 Some(ancillary_result)
285 }
286 }
287}
288
289#[derive(Debug)]
291pub struct SocketAncillary<'a> {
292 pub(crate) buffer: &'a mut [u8],
293 pub(crate) length: usize,
294 pub(crate) truncated: bool,
295}
296
297impl<'a> SocketAncillary<'a> {
298 pub fn new(buffer: &'a mut [u8]) -> Self {
308 SocketAncillary { buffer, length: 0, truncated: false }
309 }
310
311 pub fn capacity(&self) -> usize {
313 self.buffer.len()
314 }
315
316 pub fn len(&self) -> usize {
318 self.length
319 }
320
321 pub fn messages(&self) -> Messages<'_> {
323 Messages { buffer: &self.buffer[..self.length], current: None }
324 }
325
326 pub fn truncated(&self) -> bool {
328 self.truncated
329 }
330
331 pub fn add_fds(&mut self, fds: &[RawFd]) -> bool {
338 self.truncated = false;
339 add_to_ancillary_data(
340 &mut self.buffer,
341 &mut self.length,
342 fds,
343 libc::SOL_SOCKET,
344 libc::SCM_RIGHTS,
345 )
346 }
347
348 #[cfg(any(doc, target_os = "android", target_os = "linux",))]
356 pub fn add_creds(&mut self, creds: &[SocketCred]) -> bool {
357 self.truncated = false;
358 add_to_ancillary_data(
359 &mut self.buffer,
360 &mut self.length,
361 creds,
362 libc::SOL_SOCKET,
363 libc::SCM_CREDENTIALS,
364 )
365 }
366
367 pub fn clear(&mut self) {
369 self.length = 0;
370 self.truncated = false;
371 }
372}