// posix_socket/ancillary.rs

// Copied from a pull request to the Rust standard library.
// PR: https://github.com/rust-lang/rust/pull/69864/files
// File downloaded from: https://raw.githubusercontent.com/rust-lang/rust/20c88ddd5fe668b29e8fc2c3838710093e8eb94b/library/std/src/sys/unix/ext/net/ancillary.rs
4
use core::convert::TryFrom;
use core::marker::PhantomData;
use core::mem::{size_of, zeroed};
use core::ptr::{eq, read_unaligned};
use core::slice::from_raw_parts;
use std::os::unix::io::RawFd;

#[cfg(any(target_os = "android", target_os = "linux",))]
use libc::{gid_t, pid_t, uid_t};
14
15
/// Integer type of `cmsghdr::cmsg_len` and `msghdr::msg_controllen` on targets whose
/// libc bindings use `socklen_t` for them.
#[cfg(any(
	all(target_os = "linux", target_env = "musl"),
	target_os = "dragonfly",
	target_os = "emscripten",
	target_os = "freebsd",
	target_os = "netbsd",
	target_os = "openbsd",
))]
pub(crate) type CmsgLen = libc::socklen_t;

/// Integer type of `cmsghdr::cmsg_len` and `msghdr::msg_controllen` on Android and
/// glibc-based Linux.
#[cfg(any(all(target_os = "linux", target_env = "gnu"), target_os = "android"))]
pub(crate) type CmsgLen = usize;
28
29fn add_to_ancillary_data<T>(
30	buffer: &mut [u8],
31	length: &mut usize,
32	source: &[T],
33	cmsg_level: libc::c_int,
34	cmsg_type: libc::c_int,
35) -> bool {
36	let source_len = if let Some(source_len) = source.len().checked_mul(size_of::<T>()) {
37		if let Ok(source_len) = u32::try_from(source_len) {
38			source_len
39		} else {
40			return false;
41		}
42	} else {
43		return false;
44	};
45
46	unsafe {
47		let additional_space = libc::CMSG_SPACE(source_len) as usize;
48
49		let new_length = if let Some(new_length) = additional_space.checked_add(*length) {
50			new_length
51		} else {
52			return false;
53		};
54
55		if new_length > buffer.len() {
56			return false;
57		}
58
59		for byte in &mut buffer[*length..new_length] {
60			*byte = 0;
61		}
62
63		*length = new_length;
64
65		let mut msg: libc::msghdr = zeroed();
66		msg.msg_control = buffer.as_mut_ptr().cast();
67		msg.msg_controllen = *length as CmsgLen;
68
69		let mut cmsg = libc::CMSG_FIRSTHDR(&msg);
70		let mut previous_cmsg = cmsg;
71		while !cmsg.is_null() {
72			previous_cmsg = cmsg;
73			cmsg = libc::CMSG_NXTHDR(&msg, cmsg);
74		}
75
76		if previous_cmsg.is_null() {
77			return false;
78		}
79
80		(*previous_cmsg).cmsg_level = cmsg_level;
81		(*previous_cmsg).cmsg_type = cmsg_type;
82		(*previous_cmsg).cmsg_len = libc::CMSG_LEN(source_len) as CmsgLen;
83
84		let data = libc::CMSG_DATA(previous_cmsg).cast();
85
86		libc::memcpy(data, source.as_ptr().cast(), source_len as usize);
87	}
88	true
89}
90
/// Iterator that decodes the payload bytes of a control message as consecutive `T` values.
///
/// Each step performs an unaligned read of `size_of::<T>()` bytes. A trailing fragment
/// shorter than one `T` is never yielded.
struct AncillaryDataIter<'a, T> {
	data: &'a [u8],
	phantom: PhantomData<T>,
}

impl<'a, T> AncillaryDataIter<'a, T> {
	/// Wraps the payload bytes `data` of a control message.
	///
	/// # Safety
	///
	/// `data` must contain a valid control message.
	unsafe fn new(data: &'a [u8]) -> AncillaryDataIter<'a, T> {
		let phantom = PhantomData;
		Self { data, phantom }
	}
}

impl<'a, T> Iterator for AncillaryDataIter<'a, T> {
	type Item = T;

	fn next(&mut self) -> Option<T> {
		let width = size_of::<T>();
		if self.data.len() < width {
			return None;
		}

		let (unit, rest) = self.data.split_at(width);
		self.data = rest;

		// SAFETY: `unit` holds exactly `size_of::<T>()` initialized bytes and
		// `read_unaligned` has no alignment requirement. The contract of `new` makes the
		// caller responsible for those bytes forming a valid `T`.
		Some(unsafe { read_unaligned(unit.as_ptr() as *const T) })
	}
}
122
123/// Unix credential.
124#[cfg(any(doc, target_os = "android", target_os = "linux",))]
125#[derive(Clone)]
126pub struct SocketCred(libc::ucred);
127
128#[cfg(any(doc, target_os = "android", target_os = "linux",))]
129impl SocketCred {
130	/// Create a Unix credential struct.
131	///
132	/// PID, UID and GID is set to 0.
133	pub fn new() -> SocketCred {
134		SocketCred(libc::ucred { pid: 0, uid: 0, gid: 0 })
135	}
136
137	/// Set the PID.
138	pub fn set_pid(&mut self, pid: pid_t) {
139		self.0.pid = pid;
140	}
141
142	/// Get the current PID.
143	pub fn get_pid(&self) -> pid_t {
144		self.0.pid
145	}
146
147	/// Set the UID.
148	pub fn set_uid(&mut self, uid: uid_t) {
149		self.0.uid = uid;
150	}
151
152	/// Get the current UID.
153	pub fn get_uid(&self) -> uid_t {
154		self.0.uid
155	}
156
157	/// Set the GID.
158	pub fn set_gid(&mut self, gid: gid_t) {
159		self.0.gid = gid;
160	}
161
162	/// Get the current GID.
163	pub fn get_gid(&self) -> gid_t {
164		self.0.gid
165	}
166}
167
168/// This control message contains file descriptors.
169///
170/// The level is equal to `SOL_SOCKET` and the type is equal to `SCM_RIGHTS`.
171pub struct ScmRights<'a>(AncillaryDataIter<'a, RawFd>);
172
173impl<'a> Iterator for ScmRights<'a> {
174	type Item = RawFd;
175
176	fn next(&mut self) -> Option<RawFd> {
177		self.0.next()
178	}
179}
180
181/// This control message contains unix credentials.
182///
183/// The level is equal to `SOL_SOCKET` and the type is equal to `SCM_CREDENTIALS` or `SCM_CREDS`.
184#[cfg(any(doc, target_os = "android", target_os = "linux",))]
185pub struct ScmCredentials<'a>(AncillaryDataIter<'a, libc::ucred>);
186
187#[cfg(any(doc, target_os = "android", target_os = "linux",))]
188impl<'a> Iterator for ScmCredentials<'a> {
189	type Item = SocketCred;
190
191	fn next(&mut self) -> Option<SocketCred> {
192		Some(SocketCred(self.0.next()?))
193	}
194}
195
/// The error type returned when parsing the type of a control message fails.
#[non_exhaustive]
#[derive(Debug)]
pub enum AncillaryError {
	/// The control message has a level/type combination that is not recognized.
	Unknown {
		/// The `cmsg_level` of the unrecognized control message.
		cmsg_level: i32,
		/// The `cmsg_type` of the unrecognized control message.
		cmsg_type: i32,
	},
}
202
203/// This enum represent one control message of variable type.
204pub enum AncillaryData<'a> {
205	ScmRights(ScmRights<'a>),
206	#[cfg(any(doc, target_os = "android", target_os = "linux",))]
207	ScmCredentials(ScmCredentials<'a>),
208}
209
210impl<'a> AncillaryData<'a> {
211	/// Create a `AncillaryData::ScmRights` variant.
212	///
213	/// # Safety
214	///
215	/// `data` must contain a valid control message and the control message must be type of
216	/// `SOL_SOCKET` and level of `SCM_RIGHTS`.
217	unsafe fn as_rights(data: &'a [u8]) -> Self {
218		let ancillary_data_iter = AncillaryDataIter::new(data);
219		let scm_rights = ScmRights(ancillary_data_iter);
220		AncillaryData::ScmRights(scm_rights)
221	}
222
223	/// Create a `AncillaryData::ScmCredentials` variant.
224	///
225	/// # Safety
226	///
227	/// `data` must contain a valid control message and the control message must be type of
228	/// `SOL_SOCKET` and level of `SCM_CREDENTIALS` or `SCM_CREDENTIALS`.
229	#[cfg(any(doc, target_os = "android", target_os = "linux",))]
230	unsafe fn as_credentials(data: &'a [u8]) -> Self {
231		let ancillary_data_iter = AncillaryDataIter::new(data);
232		let scm_credentials = ScmCredentials(ancillary_data_iter);
233		AncillaryData::ScmCredentials(scm_credentials)
234	}
235
236	fn try_from_cmsghdr(cmsg: &'a libc::cmsghdr) -> Result<Self, AncillaryError> {
237		unsafe {
238			let cmsg_len_zero = libc::CMSG_LEN(0) as CmsgLen;
239			let data_len = (*cmsg).cmsg_len - cmsg_len_zero;
240			let data = libc::CMSG_DATA(cmsg).cast();
241			let data = from_raw_parts(data, data_len as usize);
242
243			match (*cmsg).cmsg_level {
244				libc::SOL_SOCKET => match (*cmsg).cmsg_type {
245					libc::SCM_RIGHTS => Ok(AncillaryData::as_rights(data)),
246					#[cfg(any(target_os = "android", target_os = "linux",))]
247					libc::SCM_CREDENTIALS => Ok(AncillaryData::as_credentials(data)),
248					cmsg_type => {
249						Err(AncillaryError::Unknown { cmsg_level: libc::SOL_SOCKET, cmsg_type })
250					}
251				},
252				cmsg_level => {
253					Err(AncillaryError::Unknown { cmsg_level, cmsg_type: (*cmsg).cmsg_type })
254				}
255			}
256		}
257	}
258}
259
260/// This struct is used to iterate through the control messages.
261pub struct Messages<'a> {
262	buffer: &'a [u8],
263	current: Option<&'a libc::cmsghdr>,
264}
265
266impl<'a> Iterator for Messages<'a> {
267	type Item = Result<AncillaryData<'a>, AncillaryError>;
268
269	fn next(&mut self) -> Option<Self::Item> {
270		unsafe {
271			let mut msg: libc::msghdr = zeroed();
272			msg.msg_control = self.buffer.as_ptr() as *mut _;
273			msg.msg_controllen = self.buffer.len() as CmsgLen;
274
275			let cmsg = if let Some(current) = self.current {
276				libc::CMSG_NXTHDR(&msg, current)
277			} else {
278				libc::CMSG_FIRSTHDR(&msg)
279			};
280
281			let cmsg = cmsg.as_ref()?;
282			self.current = Some(cmsg);
283			let ancillary_result = AncillaryData::try_from_cmsghdr(cmsg);
284			Some(ancillary_result)
285		}
286	}
287}
288
/// Ancillary (control) data buffer for a Unix socket message.
#[derive(Debug)]
pub struct SocketAncillary<'a> {
	/// Backing storage for the control messages.
	pub(crate) buffer: &'a mut [u8],
	/// Number of leading bytes of `buffer` currently holding control messages.
	pub(crate) length: usize,
	/// Whether the ancillary data was truncated during a recv operation.
	pub(crate) truncated: bool,
}
296
297impl<'a> SocketAncillary<'a> {
298	/// Create an ancillary data with the given buffer.
299	///
300	/// # Example
301	///
302	/// ```no_run
303	/// use posix_socket::ancillary::SocketAncillary;
304	/// let mut ancillary_buffer = [0; 128];
305	/// let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
306	/// ```
307	pub fn new(buffer: &'a mut [u8]) -> Self {
308		SocketAncillary { buffer, length: 0, truncated: false }
309	}
310
311	/// Returns the capacity of the buffer.
312	pub fn capacity(&self) -> usize {
313		self.buffer.len()
314	}
315
316	/// Returns the number of used bytes.
317	pub fn len(&self) -> usize {
318		self.length
319	}
320
321	/// Returns the iterator of the control messages.
322	pub fn messages(&self) -> Messages<'_> {
323		Messages { buffer: &self.buffer[..self.length], current: None }
324	}
325
326	/// Is `true` if during a recv operation the ancillary was truncated.
327	pub fn truncated(&self) -> bool {
328		self.truncated
329	}
330
331	/// Add file descriptors to the ancillary data.
332	///
333	/// The function returns `true` if there was enough space in the buffer.
334	/// If there was not enough space then no file descriptors was appended.
335	/// Technically, that means this operation adds a control message with the level `SOL_SOCKET`
336	/// and type `SCM_RIGHTS`.
337	pub fn add_fds(&mut self, fds: &[RawFd]) -> bool {
338		self.truncated = false;
339		add_to_ancillary_data(
340			&mut self.buffer,
341			&mut self.length,
342			fds,
343			libc::SOL_SOCKET,
344			libc::SCM_RIGHTS,
345		)
346	}
347
348	/// Add credentials to the ancillary data.
349	///
350	/// The function returns `true` if there was enough space in the buffer.
351	/// If there was not enough space then no credentials was appended.
352	/// Technically, that means this operation adds a control message with the level `SOL_SOCKET`
353	/// and type `SCM_CREDENTIALS` or `SCM_CREDS`.
354	///
355	#[cfg(any(doc, target_os = "android", target_os = "linux",))]
356	pub fn add_creds(&mut self, creds: &[SocketCred]) -> bool {
357		self.truncated = false;
358		add_to_ancillary_data(
359			&mut self.buffer,
360			&mut self.length,
361			creds,
362			libc::SOL_SOCKET,
363			libc::SCM_CREDENTIALS,
364		)
365	}
366
367	/// Clears the ancillary data, removing all values.
368	pub fn clear(&mut self) {
369		self.length = 0;
370		self.truncated = false;
371	}
372}