1use std::{
8 io,
9 marker::PhantomData,
10 os::{
11 fd::{AsFd, BorrowedFd, OwnedFd},
12 unix::net::UnixStream,
13 },
14};
15
/// One endpoint of a typed, bidirectional channel over a connected Unix stream
/// socket: it sends `TX` messages and receives `RX` messages, each carrying
/// file descriptors alongside fixed-size bytes.
///
/// Endpoints are created in pairs by [`channel`]; the second endpoint of a pair
/// has the `TX`/`RX` roles swapped.
pub struct Channel<TX, RX> {
    stream: UnixStream,
    // `fn(TX) -> RX` plus `fn(RX) -> TX` makes `Channel` invariant in both `TX`
    // and `RX`. Because fn pointers are always `Send + Sync`, the marker does not
    // tie the channel's auto traits to `TX`/`RX`, and it does not claim to own
    // any `TX`/`RX` values.
    _marker: PhantomData<(fn(TX) -> RX, fn(RX) -> TX)>,
}
27
28pub fn channel<TX, RX>() -> io::Result<(Channel<TX, RX>, Channel<RX, TX>)> {
29 let (a, b) = UnixStream::pair()?;
30 Ok((
31 Channel {
32 stream: a,
33 _marker: PhantomData,
34 },
35 Channel {
36 stream: b,
37 _marker: PhantomData,
38 },
39 ))
40}
41
42impl<TX, RX> Clone for Channel<TX, RX> {
43 fn clone(&self) -> Self {
44 Self {
45 stream: self.stream.try_clone().unwrap(),
46 _marker: PhantomData,
47 }
48 }
49}
50
/// Nightly implementation built on std's unstable `SocketAncillary` API.
/// It is only compiled with the `use_unstable_unix_socket_ancillary_data_2021` cfg.
#[cfg(use_unstable_unix_socket_ancillary_data_2021)]
mod sys {
    use super::*;
    use std::os::fd::FromRawFd;
    use std::os::unix::net::{AncillaryData, SocketAncillary};

    /// Sends all of `bytes` plus `fds` (as one `SCM_RIGHTS` cmsg) in a single
    /// `send_vectored_with_ancillary` call.
    ///
    /// # Errors
    /// Fails if the fds don't fit in the fixed 64-byte ancillary buffer, on any
    /// OS error, or if fewer than `bytes.len()` bytes were written.
    pub(super) fn stream_sendmsg<const FD_LEN: usize>(
        stream: &UnixStream,
        bytes: io::IoSlice<'_>,
        fds: &[BorrowedFd<'_>; FD_LEN],
    ) -> io::Result<()> {
        let mut ancillary_buffer = [0; 64];
        let mut ancillary = SocketAncillary::new(&mut ancillary_buffer);
        // SAFETY: `BorrowedFd` is `#[repr(transparent)]` over the raw fd type
        // (`i32` on Unix), so this reinterprets the slice element-for-element.
        // The fds stay borrowed (open) for the whole call.
        let raw_fds = unsafe { &*(fds as *const [BorrowedFd<'_>] as *const [i32]) };
        if !ancillary.add_fds(raw_fds) {
            return Err(io::Error::other(format!(
                "failed to send {FD_LEN} file descriptors: \
                 the resulting cmsg doesn't fit in {} bytes",
                ancillary.capacity()
            )));
        }
        let written_len = stream.send_vectored_with_ancillary(&[bytes], &mut ancillary)?;
        if written_len != bytes.len() {
            return Err(io::Error::other(format!(
                "partial write (only {written_len} out of {})",
                bytes.len()
            )));
        }
        Ok(())
    }

    /// Receives exactly `bytes.len()` bytes plus exactly `FD_LEN` fds, which
    /// must arrive in a single `SCM_RIGHTS` cmsg.
    ///
    /// Every received fd is wrapped in an `OwnedFd` immediately. Fds that are
    /// not returned to the caller (extras, or everything on an error path) are
    /// therefore closed on drop instead of leaking.
    ///
    /// # Errors
    /// Returns an error on any OS error, a partial read, a truncated ancillary
    /// buffer, or unexpected or malformed cmsgs. All problems found while
    /// parsing the cmsgs are reported together.
    pub(super) fn stream_recvmsg<const FD_LEN: usize>(
        stream: &UnixStream,
        bytes: io::IoSliceMut<'_>,
    ) -> io::Result<[OwnedFd; FD_LEN]> {
        let mut ancillary_buffer = [0; 64];
        let mut ancillary = SocketAncillary::new(&mut ancillary_buffer);
        let expected_len = bytes.len();
        let read_len = stream.recv_vectored_with_ancillary(&mut [bytes], &mut ancillary)?;
        let partial_read = read_len != expected_len;
        let (ancillary_truncated, ancillary_capacity) =
            (ancillary.truncated(), ancillary.capacity());

        let mut errors = vec![];
        let mut accepted_fds = [(); FD_LEN].map(|()| None);
        let mut accepted_fd_count = 0;
        // Track "seen an SCM_RIGHTS cmsg" explicitly. Inferring it from
        // `accepted_fd_count == 0` would let a second cmsg through whenever the
        // first one carried no valid fds.
        let mut seen_scm_rights = false;
        for cmsg in ancillary.messages() {
            match cmsg {
                Err(err) => errors.push(format!("{err:?}")),
                Ok(AncillaryData::ScmRights(raw_fds)) => {
                    let is_first_scm_rights = !seen_scm_rights;
                    seen_scm_rights = true;
                    for raw_fd in raw_fds {
                        if raw_fd == -1 {
                            errors.push("invalid fd (-1) received".into());
                            continue;
                        }
                        // SAFETY: fds delivered via `SCM_RIGHTS` were just installed
                        // into this process's fd table, and nothing else owns them.
                        let fd = unsafe { OwnedFd::from_raw_fd(raw_fd) };
                        if is_first_scm_rights {
                            let i = accepted_fd_count;
                            accepted_fd_count += 1;
                            // Fds beyond `FD_LEN` find no slot and are closed right
                            // here. The count still records them for the error below.
                            if let Some(slot) = accepted_fds.get_mut(i) {
                                *slot = Some(fd);
                            }
                        }
                        // Fds from any later `SCM_RIGHTS` cmsg are dropped (closed).
                    }
                    if !is_first_scm_rights {
                        errors.push("received more than one SCM_RIGHTS cmsg".into());
                    }
                }
                Ok(AncillaryData::ScmCredentials(_)) => {
                    errors.push("received unexpected SCM_CREDS-like cmsg".into());
                }
            }
        }
        if accepted_fd_count != FD_LEN {
            errors.push(format!(
                "wrong number of received fds: expected {FD_LEN}, got {accepted_fd_count}"
            ))
        }

        if partial_read {
            return Err(io::Error::other(format!(
                "partial read: only {read_len} out of {expected_len}"
            )));
        }
        if ancillary_truncated {
            return Err(io::Error::other(format!(
                "truncated anciliary buffer: received cmsg doesn't fit in {ancillary_capacity} bytes"
            )));
        }

        if errors.is_empty() {
            // No errors implies exactly `FD_LEN` fds were accepted, so every slot is `Some`.
            Ok(accepted_fds.map(Option::unwrap))
        } else {
            Err(io::Error::other(if errors.len() == 1 {
                errors.pop().unwrap()
            } else {
                format!("errors during receiving:\n  {}", errors.join("\n  "))
            }))
        }
    }
}
// NOTE(review): the `msghdr`/`cmsghdr` layouts and constant values below are
// hard-coded for Linux (glibc layouts; `SOL_SOCKET == 1` on common
// architectures). Other Unixes differ, so confirm this fallback is only ever
// built for Linux targets.
/// Stable-Rust fallback: passes fds with `SCM_RIGHTS` through hand-declared
/// `sendmsg`/`recvmsg` bindings.
#[cfg(not(use_unstable_unix_socket_ancillary_data_2021))]
mod sys {
    #![allow(non_camel_case_types)]

    /// Stand-in for `io::Error::other`, which is only stable since Rust 1.74.
    fn io_error_other(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
        io::Error::new(io::ErrorKind::Other, error)
    }

    use super::*;
    use std::{
        ffi::{c_int, c_void},
        os::fd::FromRawFd,
        ptr,
    };

    type socklen_t = u32;

    /// `struct msghdr`, generic over the iovec type. `IoSlice` and `IoSliceMut`
    /// are both ABI-compatible with `struct iovec` on Unix.
    #[repr(C)]
    struct msghdr<IOV> {
        msg_name: *mut c_void,
        msg_namelen: socklen_t,
        msg_iov: *mut IOV,
        msg_iovlen: usize,
        msg_control: *mut c_void,
        msg_controllen: usize,
        msg_flags: c_int,
    }

    const SOL_SOCKET: c_int = 1;
    const SCM_RIGHTS: c_int = 1;
    /// Bit set in `msghdr::msg_flags` when control data didn't fit the buffer.
    const MSG_CTRUNC: c_int = 0x8;
    /// `recvmsg` flag that sets close-on-exec on fds received via `SCM_RIGHTS`
    /// (std's unstable ancillary-data receive path passes it on Linux too).
    const MSG_CMSG_CLOEXEC: c_int = 0x4000_0000;

    /// `struct cmsghdr`.
    #[repr(C)]
    struct cmsghdr {
        cmsg_len: usize,
        cmsg_level: c_int,
        cmsg_type: c_int,
    }
    // The header needs no padding before the cmsg data
    // (`CMSG_ALIGN(sizeof(cmsghdr)) == sizeof(cmsghdr)`), which `CMsgBuf` relies on.
    const _: () = assert!(std::mem::size_of::<cmsghdr>() % std::mem::size_of::<usize>() == 0);

    extern "C" {
        fn sendmsg(
            sockfd: BorrowedFd<'_>,
            msg: *const msghdr<io::IoSlice<'_>>,
            flags: c_int,
        ) -> isize;
        fn recvmsg(
            sockfd: BorrowedFd<'_>,
            msg: *mut msghdr<io::IoSliceMut<'_>>,
            flags: c_int,
        ) -> isize;
    }

    /// A control buffer holding exactly one `SCM_RIGHTS` cmsg with `FD_LEN` fds.
    /// With a 4-byte `FD` type its size works out to
    /// `CMSG_SPACE(FD_LEN * sizeof(int))`.
    #[repr(C)]
    struct CMsgBuf<FD, const FD_LEN: usize> {
        header: cmsghdr,
        fds: [FD; FD_LEN],
    }

    /// `CMSG_LEN(fd_count * sizeof(int))`: the `cmsg_len` of an `SCM_RIGHTS`
    /// cmsg carrying `fd_count` fds.
    const fn scm_rights_cmsg_len(fd_count: usize) -> usize {
        std::mem::size_of::<cmsghdr>() + fd_count * std::mem::size_of::<c_int>()
    }

    /// Sends all of `bytes` plus `fds` (as one `SCM_RIGHTS` cmsg) in one `sendmsg`.
    ///
    /// # Errors
    /// Returns an error on any OS error, or if fewer than `bytes.len()` bytes were written.
    pub(super) fn stream_sendmsg<const FD_LEN: usize>(
        stream: &UnixStream,
        mut bytes: io::IoSlice<'_>,
        fds: &[BorrowedFd<'_>; FD_LEN],
    ) -> io::Result<()> {
        let mut cmsg_buf = CMsgBuf {
            header: cmsghdr {
                cmsg_len: scm_rights_cmsg_len(FD_LEN),
                cmsg_level: SOL_SOCKET,
                cmsg_type: SCM_RIGHTS,
            },
            fds: *fds,
        };

        // SAFETY: the `msghdr` points at one valid iovec (`bytes`) and at
        // `cmsg_buf` with its exact size. Both outlive the call, and the fds in
        // `cmsg_buf` stay borrowed (open) throughout.
        let written_len = unsafe {
            sendmsg(
                stream.as_fd(),
                &msghdr {
                    msg_name: ptr::null_mut(),
                    msg_namelen: 0,
                    msg_iov: &mut bytes,
                    msg_iovlen: 1,
                    msg_control: &mut cmsg_buf as *mut _ as *mut c_void,
                    msg_controllen: std::mem::size_of_val(&cmsg_buf),
                    msg_flags: 0,
                },
                0,
            )
        };
        if written_len == -1 {
            return Err(io::Error::last_os_error());
        }
        if written_len as usize != bytes.len() {
            return Err(io_error_other(format!(
                "partial write (only {written_len} out of {})",
                bytes.len()
            )));
        }
        Ok(())
    }

    /// Receives exactly `bytes.len()` bytes plus exactly `FD_LEN` fds in one
    /// `SCM_RIGHTS` cmsg. Received fds are close-on-exec.
    ///
    /// Every fd the kernel delivers is taken into an `OwnedFd` before any
    /// validation. On every error path they are therefore closed, not leaked.
    ///
    /// # Errors
    /// Returns an error on any OS error, a partial read, truncated control
    /// data, or control data that isn't exactly one `SCM_RIGHTS` cmsg with
    /// `FD_LEN` valid fds.
    pub(super) fn stream_recvmsg<const FD_LEN: usize>(
        stream: &UnixStream,
        mut bytes: io::IoSliceMut<'_>,
    ) -> io::Result<[OwnedFd; FD_LEN]> {
        let expected_len = bytes.len();

        // Receive into plain `c_int`s, so only fds the kernel actually reports
        // ever become `OwnedFd`s. (A zeroed `Option<OwnedFd>` buffer would
        // instead hold `Some(fd 0)`, which would be closed on drop.)
        let mut cmsg_buf = CMsgBuf::<c_int, FD_LEN> {
            header: cmsghdr {
                cmsg_len: 0,
                cmsg_level: 0,
                cmsg_type: 0,
            },
            fds: [-1; FD_LEN],
        };
        let expected_cmsg_len = scm_rights_cmsg_len(FD_LEN);
        let expected_msg_controllen = std::mem::size_of_val(&cmsg_buf);

        let mut msg = msghdr {
            msg_name: ptr::null_mut(),
            msg_namelen: 0,
            msg_iov: &mut bytes,
            msg_iovlen: 1,
            msg_control: &mut cmsg_buf as *mut _ as *mut c_void,
            msg_controllen: expected_msg_controllen,
            msg_flags: 0,
        };

        // SAFETY: `msg` points at one valid, writable iovec (`bytes`) and at
        // `cmsg_buf` with its exact size. Both outlive the call.
        let read_len = unsafe { recvmsg(stream.as_fd(), &mut msg, MSG_CMSG_CLOEXEC) };
        if read_len == -1 {
            return Err(io::Error::last_os_error());
        }

        // Take ownership of whatever fds arrived, before checking anything else.
        let header = &cmsg_buf.header;
        let received_fd_count = if msg.msg_controllen >= std::mem::size_of::<cmsghdr>()
            && (header.cmsg_level, header.cmsg_type) == (SOL_SOCKET, SCM_RIGHTS)
        {
            let data_len = header
                .cmsg_len
                .min(msg.msg_controllen)
                .saturating_sub(std::mem::size_of::<cmsghdr>());
            (data_len / std::mem::size_of::<c_int>()).min(FD_LEN)
        } else {
            0
        };
        let mut received = [(); FD_LEN].map(|()| None::<OwnedFd>);
        for (slot, &raw_fd) in received
            .iter_mut()
            .zip(cmsg_buf.fds.iter().take(received_fd_count))
        {
            if raw_fd != -1 {
                // SAFETY: fds delivered via `SCM_RIGHTS` were just installed into
                // this process's fd table by the kernel, and nothing else owns them.
                *slot = Some(unsafe { OwnedFd::from_raw_fd(raw_fd) });
            }
        }

        if read_len as usize != expected_len {
            return Err(io_error_other(format!(
                "partial read: only {read_len} out of {expected_len}"
            )));
        }

        // With `MSG_CTRUNC`, the kernel dropped fds that didn't fit, and the
        // length checks below could still pass.
        if (msg.msg_flags & MSG_CTRUNC) != 0 {
            return Err(io_error_other(format!(
                "truncated ancillary buffer: received cmsg doesn't fit in {expected_msg_controllen} bytes"
            )));
        }

        if msg.msg_controllen != expected_msg_controllen {
            return Err(io_error_other(format!(
                "recvmsg msg_controllen mismatch: got {}, expected {expected_msg_controllen}",
                msg.msg_controllen,
            )));
        }

        if cmsg_buf.header.cmsg_len != expected_cmsg_len {
            return Err(io_error_other(format!(
                "recvmsg cmsg_len mismatch: got {}, expected {expected_cmsg_len}",
                cmsg_buf.header.cmsg_len
            )));
        }

        if (cmsg_buf.header.cmsg_level, cmsg_buf.header.cmsg_type) != (SOL_SOCKET, SCM_RIGHTS) {
            return Err(io_error_other("unsupported non-SCM_RIGHTS CMSG"));
        }

        if received.iter().any(Option::is_none) {
            return Err(io_error_other("recvmsg got invalid (-1) fds"));
        }

        Ok(received.map(Option::unwrap))
    }
}
327
328impl<TX, RX> Channel<TX, RX> {
329 pub fn send<const TX_BYTE_LEN: usize, const TX_FD_LEN: usize>(&self, msg: TX) -> io::Result<()>
330 where
331 TX: FixedSizeEncoding<TX_BYTE_LEN, TX_FD_LEN>,
332 {
333 assert_ne!(
334 TX_FD_LEN,
335 0,
336 "Channel<{}, _> unsupported (lacks file descriptors)",
337 std::any::type_name::<TX>()
338 );
339
340 let (bytes, fds) = msg.encode();
341 sys::stream_sendmsg(&self.stream, io::IoSlice::new(&bytes), &fds)
342 }
343
344 pub fn recv<const RX_BYTE_LEN: usize, const RX_FD_LEN: usize>(&self) -> io::Result<RX>
345 where
346 RX: FixedSizeEncoding<RX_BYTE_LEN, RX_FD_LEN>,
347 {
348 assert_ne!(
349 RX_FD_LEN,
350 0,
351 "Channel<_, {}> unsupported (lacks file descriptors)",
352 std::any::type_name::<TX>()
353 );
354
355 let mut bytes = [0; RX_BYTE_LEN];
357 let fds = sys::stream_recvmsg(&self.stream, io::IoSliceMut::new(&mut bytes))?;
358 Ok(RX::decode(bytes, fds))
359 }
360
361 pub fn into_child_process_inheritable(self) -> io::Result<InheritableChannel<TX, RX>> {
366 extern "C" {
367 fn dup(fd: BorrowedFd<'_>) -> Option<OwnedFd>;
368 }
369 Ok(InheritableChannel(Self {
370 stream: unsafe { dup(self.stream.as_fd()) }
371 .ok_or_else(|| io::Error::last_os_error())?
372 .into(),
373 _marker: PhantomData,
374 }))
375 }
376}
377
/// A [`Channel`] meant to hold an inheritable (not close-on-exec) socket fd, so
/// the fd can be handed down to a child process. [`Channel::into_child_process_inheritable`]
/// creates one via `dup`; [`InheritableChannel::into_uninheritable`] converts back.
pub struct InheritableChannel<TX, RX>(Channel<TX, RX>);
383
384impl<TX, RX> AsFd for InheritableChannel<TX, RX> {
385 fn as_fd(&self) -> BorrowedFd<'_> {
386 self.0.stream.as_fd()
387 }
388}
389
390impl<TX, RX> From<OwnedFd> for InheritableChannel<TX, RX> {
391 fn from(fd: OwnedFd) -> Self {
392 Self(Channel {
393 stream: UnixStream::from(fd),
394 _marker: PhantomData,
395 })
396 }
397}
398
399impl<TX, RX> InheritableChannel<TX, RX> {
400 pub fn into_uninheritable(self) -> io::Result<Channel<TX, RX>> {
403 let Self(mut channel) = self;
404 channel.stream = channel.stream.as_fd().try_clone_to_owned()?.into();
405 Ok(channel)
406 }
407}
408
/// An uninhabited type: no value of it can ever be constructed.
// NOTE(review): its uses aren't visible here. It is presumably meant as the
// `TX`/`RX` of a channel direction that never carries messages; confirm.
pub enum Never {}
412
/// A message type with a fixed-size wire encoding: exactly `BYTE_LEN` bytes
/// plus exactly `FD_LEN` file descriptors.
pub trait FixedSizeEncoding<const BYTE_LEN: usize, const FD_LEN: usize> {
    // Re-exposes the const generic parameters as associated consts, so impls can
    // write e.g. `[u8; Self::BYTE_LEN]` in their signatures.
    const BYTE_LEN: usize = BYTE_LEN;
    const FD_LEN: usize = FD_LEN;

    /// Encodes `self` into its bytes plus fds that are borrowed from `self`
    /// (not duplicated).
    fn encode(&self) -> ([u8; BYTE_LEN], [BorrowedFd<'_>; FD_LEN]);
    /// Rebuilds a value from received bytes, taking ownership of the received fds.
    fn decode(bytes: [u8; BYTE_LEN], fds: [OwnedFd; FD_LEN]) -> Self;
}
429
430impl<
432 const BYTE_LEN: usize,
433 const FD_LEN: usize,
434 A: FixedSizeEncoding<BYTE_LEN, 0>,
435 B: FixedSizeEncoding<0, FD_LEN>,
436 > FixedSizeEncoding<BYTE_LEN, FD_LEN> for (A, B)
437{
438 fn encode(&self) -> ([u8; BYTE_LEN], [BorrowedFd<'_>; FD_LEN]) {
439 let ((bytes, []), ([], fds)) = (self.0.encode(), self.1.encode());
440 (bytes, fds)
441 }
442 fn decode(bytes: [u8; BYTE_LEN], fds: [OwnedFd; FD_LEN]) -> Self {
443 (A::decode(bytes, []), B::decode([], fds))
444 }
445}
446
447macro_rules! fixed_size_le_prim_impls {
448 ($($ty:ident)*) => {
449 $(impl FixedSizeEncoding<{(Self::BITS / 8) as usize}, 0> for $ty {
450 fn encode(&self) -> ([u8; Self::BYTE_LEN], [BorrowedFd<'_>; 0]) {
451 (self.to_le_bytes(), [])
452 }
453 fn decode(bytes: [u8; Self::BYTE_LEN], []: [OwnedFd; 0]) -> Self {
454 Self::from_le_bytes(bytes)
455 }
456 })*
457 }
458}
459fixed_size_le_prim_impls!(u16 u32 u64 u128);
460
461impl FixedSizeEncoding<0, 1> for OwnedFd {
462 fn encode(&self) -> ([u8; 0], [BorrowedFd<'_>; 1]) {
463 ([], [self.as_fd()])
464 }
465 fn decode([]: [u8; 0], [fd]: [OwnedFd; 1]) -> Self {
466 fd
467 }
468}