1use std::fs::File;
11use std::mem::size_of;
12use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
13use std::os::unix::net::{UnixDatagram, UnixStream};
14use std::ptr;
15
16#[cfg(target_os = "linux")]
17use libc::MSG_NOSIGNAL;
18use libc::{self, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, SCM_RIGHTS, SOL_SOCKET};
19
20use crate::common::{SysError, SysResult};
21
// Rounds `$bytes` (converted to `usize`) up to the next multiple of
// `size_of::<libc::c_long>()`.
//
// This is a macro rather than a function so that it can be expanded inside constant
// expressions (it is used, through `CMSG_SPACE!`, to size a `const`).
macro_rules! CMSG_ALIGN {
    ($bytes:expr) => {{
        let unit = ::std::mem::size_of::<libc::c_long>();
        (($bytes) as usize + unit - 1) & !(unit - 1)
    }};
}
31
// Bytes occupied by one control message carrying `$payload_len` bytes of payload: the
// payload rounded up by `CMSG_ALIGN!`, plus one `cmsghdr`.
macro_rules! CMSG_SPACE {
    ($payload_len:expr) => {
        CMSG_ALIGN!($payload_len) + ::std::mem::size_of::<cmsghdr>()
    };
}
37
// Size in bytes of `$fd_count` raw file descriptors.
macro_rules! FD_LENGTH {
    ($fd_count:expr) => {
        $fd_count * std::mem::size_of::<RawFd>()
    };
}
43
44#[allow(non_snake_case)]
48#[inline(always)]
49fn CMSG_DATA(cmsg_buffer: *mut libc::cmsghdr) -> *mut RawFd {
50 cmsg_buffer.wrapping_offset(1) as *mut RawFd
52}
53
54#[cfg(not(target_env = "musl"))]
55fn new_msghdr(iovecs: &mut [libc::iovec]) -> libc::msghdr {
56 libc::msghdr {
57 msg_name: ptr::null_mut(),
58 msg_namelen: 0,
59 msg_iov: iovecs.as_mut_ptr(),
60 msg_iovlen: iovecs.len() as _,
61 msg_control: ptr::null_mut(),
62 msg_controllen: 0,
63 msg_flags: 0,
64 }
65}
66
// Builds a `msghdr` that points at `iovecs` and has no peer address and no control buffer
// (musl targets).
//
// Panics if `iovecs.len()` does not fit in an `i32`, the width used for `msg_iovlen` here.
#[cfg(target_env = "musl")]
fn new_msghdr(iovecs: &mut [iovec]) -> msghdr {
    assert!(iovecs.len() <= (std::i32::MAX as usize));
    let iov_count = iovecs.len() as i32;

    // Start from an all-zero header and fill in only the fields we care about.
    // NOTE(review): presumably a struct literal is not usable for musl's `msghdr` — confirm.
    let mut hdr: msghdr = unsafe { std::mem::zeroed() };
    hdr.msg_iov = iovecs.as_mut_ptr();
    hdr.msg_iovlen = iov_count;
    hdr.msg_name = ptr::null_mut();
    hdr.msg_control = ptr::null_mut();
    hdr
}
77
78#[cfg(not(target_env = "musl"))]
79fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) {
80 msg.msg_controllen = cmsg_capacity as _;
81}
82
// Records `cmsg_capacity` as the control-buffer length of `msg` (musl targets).
//
// Panics if `cmsg_capacity` does not fit in the `u32` stored in `msg_controllen` here.
#[cfg(target_env = "musl")]
fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) {
    assert!(cmsg_capacity <= (std::u32::MAX as usize));
    let controllen = cmsg_capacity as u32;
    msg.msg_controllen = controllen;
}
88
89const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32);
106
107impl CmsgBuffer {
108 fn with_capacity(capacity: usize) -> CmsgBuffer {
109 let cap_in_cmsghdr_units =
110 (capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>();
111 if capacity <= CMSG_BUFFER_INLINE_CAPACITY {
112 CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8])
113 } else {
114 CmsgBuffer::Heap(
115 vec![
116 cmsghdr {
117 cmsg_len: 0,
118 cmsg_level: 0,
119 cmsg_type: 0,
120 #[cfg(all(target_env = "musl", target_pointer_width = "64"))]
121 __pad1: 0,
122 };
123 cap_in_cmsghdr_units
124 ]
125 .into_boxed_slice(),
126 )
127 }
128 }
129
130 fn as_mut_ptr(&mut self) -> *mut libc::cmsghdr {
131 match self {
132 CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr,
133 CmsgBuffer::Heap(a) => a.as_mut_ptr(),
134 }
135 }
136}
137
// Backing storage for the control (ancillary) data handed to `sendmsg`/`recvmsg` through
// `msghdr::msg_control`.
enum CmsgBuffer {
    // Fixed-size array of `u64` words covering at least `CMSG_BUFFER_INLINE_CAPACITY` bytes
    // (the byte count is rounded up to whole 8-byte words).
    Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]),
    // Heap-allocated slice of `cmsghdr` elements, used when the requested capacity exceeds
    // `CMSG_BUFFER_INLINE_CAPACITY`.
    Heap(Box<[cmsghdr]>),
}
142
143fn raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> SysResult<usize> {
144 unsafe {
145 let fd_len = FD_LENGTH!(out_fds.len());
146 let cmsg_capacity = libc::CMSG_SPACE(fd_len as _);
148 let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity as _);
149 let mut iovecs = Vec::with_capacity(out_data.len());
150 for data in out_data {
151 iovecs.push(iovec {
152 iov_base: data.as_ptr() as *mut c_void,
153 iov_len: data.size(),
154 });
155 }
156
157 let mut msghdr = new_msghdr(&mut iovecs);
158 if !out_fds.is_empty() {
159 let cmsg = libc::cmsghdr {
160 cmsg_len: libc::CMSG_LEN(fd_len as u32) as _,
161 cmsg_level: SOL_SOCKET,
162 cmsg_type: SCM_RIGHTS,
163 #[cfg(all(target_env = "musl", target_pointer_width = "64"))]
164 __pad1: 0,
165 };
166 ptr::write_unaligned(cmsg_buffer.as_mut_ptr(), cmsg);
168 ptr::copy_nonoverlapping(
171 out_fds.as_ptr(),
172 libc::CMSG_DATA(cmsg_buffer.as_mut_ptr()) as *mut _,
173 out_fds.len(),
174 );
175
176 msghdr.msg_control = cmsg_buffer.as_mut_ptr() as *mut _;
177 set_msg_controllen(&mut msghdr, cmsg_capacity as _);
178 }
179
180 #[cfg(target_os = "linux")]
184 let write_count = sendmsg(fd, &msghdr, MSG_NOSIGNAL);
185 #[cfg(target_os = "macos")]
186 let write_count = sendmsg(fd, &msghdr, 0);
187
188 if write_count == -1 {
189 Err(SysError::last())
190 } else {
191 Ok(write_count as usize)
192 }
193 }
194}
195
// Receives one message from socket `fd` with `recvmsg`, scattering the data into `iovecs`
// and copying any descriptors received through `SCM_RIGHTS` control messages into `in_fds`.
//
// Returns `(bytes_read, fds_copied)`.
//
// Errors:
// * the last OS error when `recvmsg` returns -1;
// * `ENOBUFS` when the control data was truncated (`MSG_CTRUNC`) or an `SCM_RIGHTS` message
//   carries more descriptors than there is room left in `in_fds`. In that case the
//   descriptors seen so far are closed before returning.
//
// Safety: `iovecs` is handed to `recvmsg` as-is, so every entry must describe memory that is
// valid for writes of `iov_len` bytes.
unsafe fn raw_recvmsg(
    fd: RawFd,
    iovecs: &mut [iovec],
    in_fds: &mut [RawFd],
) -> SysResult<(usize, usize)> {
    // Size the control buffer for `in_fds.len()` descriptors.
    let fd_length = FD_LENGTH!(in_fds.len());
    let cmsg_capacity = libc::CMSG_SPACE(fd_length as _) as usize;
    let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
    let mut msg = new_msghdr(iovecs);

    // A control buffer is only attached when the caller can accept descriptors; otherwise
    // `msg_control` stays null and the loop below never runs.
    if !in_fds.is_empty() {
        msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
        set_msg_controllen(&mut msg, cmsg_capacity);
    }

    let total_read = recvmsg(fd, &mut msg, 0);
    if total_read == -1 {
        return Err(SysError::last());
    }

    // No data and not even one complete control header: nothing was received.
    if total_read == 0 && (msg.msg_controllen as usize) < size_of::<cmsghdr>() {
        return Ok((0, 0));
    }

    // NOTE(review): the walk starts at `msg.msg_control` without consulting
    // `msg.msg_controllen` (unlike `libc::CMSG_FIRSTHDR`). When data arrives without control
    // data this reads the zero-filled header written by `CmsgBuffer::with_capacity` —
    // confirm that relying on this is intended.
    let mut cmsg_ptr = msg.msg_control as *mut cmsghdr;
    let mut copied_fds_count = 0;
    // Truncated control data means descriptors must be closed instead of returned.
    let mut teardown_control_data = msg.msg_flags & libc::MSG_CTRUNC != 0;

    while !cmsg_ptr.is_null() {
        // Copy the header out with an unaligned read before inspecting it.
        let cmsg = (cmsg_ptr as *mut cmsghdr).read_unaligned();
        if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
            // Number of descriptors in this message: payload length over the size of one fd.
            let fds_count =
                (cmsg.cmsg_len as usize - libc::CMSG_LEN(0) as usize) / size_of::<RawFd>();
            let fds_to_be_copied_count = std::cmp::min(in_fds.len() - copied_fds_count, fds_count);
            // More descriptors than free slots in `in_fds` also triggers the teardown
            // (presumably possible because the aligned control buffer can hold more
            // descriptors than were asked for — confirm).
            teardown_control_data |= fds_count > fds_to_be_copied_count;
            if teardown_control_data {
                // Close every descriptor carried by this message so none of them leaks.
                for fd_offset in 0..fds_count {
                    let raw_fds_ptr = CMSG_DATA(cmsg_ptr);
                    let raw_fd = *(raw_fds_ptr.wrapping_add(fd_offset)) as libc::c_int;
                    libc::close(raw_fd);
                }
            } else {
                // Append this message's descriptors after the ones already copied.
                ptr::copy_nonoverlapping(
                    CMSG_DATA(cmsg_ptr),
                    in_fds[copied_fds_count..(copied_fds_count + fds_to_be_copied_count)]
                        .as_mut_ptr(),
                    fds_to_be_copied_count,
                );

                copied_fds_count += fds_to_be_copied_count;
            }
        }

        if teardown_control_data {
            // Also close the descriptors copied from earlier messages; the return value of
            // `close` is ignored.
            for fd in in_fds.iter().take(copied_fds_count) {
                libc::close(*fd);
            }

            return Err(SysError::new(libc::ENOBUFS));
        }

        // Advance to the next control message; a null result ends the loop.
        cmsg_ptr = libc::CMSG_NXTHDR(&msg, cmsg_ptr); }

    Ok((total_read as usize, copied_fds_count))
}
343
344impl ScmSocket for UnixDatagram {
345 fn socket_fd(&self) -> RawFd {
346 self.as_raw_fd()
347 }
348}
349
350impl ScmSocket for UnixStream {
351 fn socket_fd(&self) -> RawFd {
352 self.as_raw_fd()
353 }
354}
355
/// A buffer that can be described by an `iovec` and sent with `sendmsg`.
///
/// `raw_sendmsg` stores `as_ptr()` in `iov_base` and `size()` in `iov_len`.
///
/// # Safety
///
/// NOTE(review): presumably implementors must guarantee that `as_ptr()` is valid for reads
/// of `size()` bytes for as long as the value is borrowed — confirm the intended contract.
pub unsafe trait IntoIovec {
    /// Address of the first byte of the buffer.
    fn as_ptr(&self) -> *const libc::c_void;

    /// Length of the buffer in bytes.
    fn size(&self) -> usize;
}
370
371unsafe impl<'a> IntoIovec for &'a [u8] {
374 #[cfg_attr(feature = "cargo-clippy", allow(clippy::useless_asref))]
376 fn as_ptr(&self) -> *const libc::c_void {
377 self.as_ref().as_ptr() as *const libc::c_void
378 }
379
380 fn size(&self) -> usize {
381 self.len()
382 }
383}
384
385pub trait ScmSocket {
388 fn socket_fd(&self) -> RawFd;
390
391 fn send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> SysResult<usize> {
400 self.send_with_fds(&[buf], &[fd])
401 }
402
403 fn send_with_fds<D: IntoIovec>(&self, bufs: &[D], fds: &[RawFd]) -> SysResult<usize> {
412 raw_sendmsg(self.socket_fd(), bufs, fds)
413 }
414
415 fn recv_with_fd(&self, buf: &mut [u8]) -> SysResult<(usize, Option<File>)> {
423 let mut fd = [0];
424 let mut iovecs = [libc::iovec {
425 iov_base: buf.as_mut_ptr() as *mut libc::c_void,
426 iov_len: buf.len(),
427 }];
428
429 let (read_count, fd_count) = unsafe { self.recv_with_fds(&mut iovecs[..], &mut fd)? };
432 let file = if fd_count == 0 {
433 None
434 } else {
435 Some(unsafe { File::from_raw_fd(fd[0]) })
438 };
439 Ok((read_count, file))
440 }
441
442 unsafe fn recv_with_fds(
461 &self,
462 iovecs: &mut [libc::iovec],
463 fds: &mut [RawFd],
464 ) -> SysResult<(usize, usize)> {
465 raw_recvmsg(self.socket_fd(), iovecs, fds)
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::{Read, Seek, SeekFrom, Write};
    use std::mem::size_of;
    use std::os::raw::c_long;
    use std::os::unix::net::UnixDatagram;

    use libc::cmsghdr;
    use vmm_sys_util::tempfile::TempFile;

    // Creates a temporary file holding `contents`, rewound to the start.
    fn temp_file_with(contents: &[u8]) -> File {
        let mut file = TempFile::new().unwrap().into_file();
        file.write_all(contents).unwrap();
        file.seek(SeekFrom::Start(0)).unwrap();
        file
    }

    // Sends one data byte plus four descriptors, then checks that receiving with room for
    // only `fd_slots` descriptors fails.
    fn assert_recv_fails_with_fd_slots(fd_slots: usize) {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let sent_files = [
            temp_file_with(b"foo"),
            temp_file_with(b"bar"),
            temp_file_with(b"foobar"),
            temp_file_with(b"foobarfoo"),
        ];
        let sent_fds: Vec<RawFd> = sent_files.iter().map(|file| file.as_raw_fd()).collect();
        let write_count = s1
            .send_with_fds(&[[237].as_ref()], &sent_fds)
            .expect("failed to send fd");

        assert_eq!(write_count, 1);

        let mut files = vec![0; fd_slots];
        let mut buf = [0u8];
        let mut iovecs = [iovec {
            iov_base: buf.as_mut_ptr() as *mut c_void,
            iov_len: buf.len(),
        }];
        assert!(unsafe { s2.recv_with_fds(&mut iovecs[..], &mut files).is_err() });
    }

    // `CMSG_SPACE!` is one header plus the payload padded to whole `c_long` words. The
    // expectation is derived from the word size so it holds on 32-bit targets too.
    #[test]
    fn buffer_len() {
        let word = size_of::<c_long>();
        assert_eq!(CMSG_SPACE!(0), size_of::<cmsghdr>());
        for fd_count in 1..=4 {
            let payload = fd_count * size_of::<RawFd>();
            let padded = (payload + word - 1) / word * word;
            assert_eq!(CMSG_SPACE!(payload), size_of::<cmsghdr>() + padded);
        }
    }

    // Data only: both buffers arrive concatenated and no descriptor is reported.
    #[test]
    fn send_recv_no_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let write_count = s1
            .send_with_fds(&[[1u8, 1, 2].as_ref(), [21u8, 34, 55].as_ref()], &[])
            .expect("failed to send data");

        assert_eq!(write_count, 6);

        let mut buf = [0u8; 6];
        let mut files = [0; 1];
        let mut iovecs = [iovec {
            iov_base: buf.as_mut_ptr() as *mut c_void,
            iov_len: buf.len(),
        }];
        let (read_count, file_count) = unsafe {
            s2.recv_with_fds(&mut iovecs[..], &mut files)
                .expect("failed to recv data")
        };

        assert_eq!(read_count, 6);
        assert_eq!(file_count, 0);
        assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
    }

    // Descriptor only: an empty payload still carries the descriptor across.
    #[test]
    fn send_recv_only_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let file1 = temp_file_with(b"foo");

        let write_count = s1
            .send_with_fd([].as_ref(), file1.as_raw_fd())
            .expect("failed to send fd");

        assert_eq!(write_count, 0);

        let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd");

        let mut file = file_opt.unwrap();

        assert_eq!(read_count, 0);
        assert!(file.as_raw_fd() >= 0);
        assert_ne!(file.as_raw_fd(), s1.as_raw_fd());
        assert_ne!(file.as_raw_fd(), s2.as_raw_fd());
        assert_ne!(file.as_raw_fd(), file1.as_raw_fd());

        let mut buf = String::new();
        file.read_to_string(&mut buf).unwrap();
        assert_eq!("foo".to_string(), buf);
    }

    // Data plus one descriptor, received into a buffer with room for two descriptors.
    #[test]
    fn send_recv_with_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let file1 = temp_file_with(b"foo");

        let write_count = s1
            .send_with_fds(&[[237].as_ref()], &[file1.as_raw_fd()])
            .expect("failed to send fd");

        assert_eq!(write_count, 1);

        let mut files = [0; 2];
        let mut buf = [0u8];
        let mut iovecs = [iovec {
            iov_base: buf.as_mut_ptr() as *mut c_void,
            iov_len: buf.len(),
        }];
        let (read_count, file_count) = unsafe {
            s2.recv_with_fds(&mut iovecs[..], &mut files)
                .expect("failed to recv fd")
        };

        assert_eq!(read_count, 1);
        assert_eq!(buf[0], 237);
        assert_eq!(file_count, 1);
        assert!(files[0] >= 0);
        assert_ne!(files[0], s1.as_raw_fd());
        assert_ne!(files[0], s2.as_raw_fd());
        assert_ne!(files[0], file1.as_raw_fd());

        let mut file = unsafe { File::from_raw_fd(files[0]) };
        let mut buf = String::new();
        file.read_to_string(&mut buf).unwrap();
        assert_eq!("foo".to_string(), buf);
        assert_ne!("bar".to_string(), buf);
    }

    // Four descriptors sent, room for two on the receiving side.
    #[test]
    fn send_more_recv_less1() {
        assert_recv_fails_with_fd_slots(2);
    }

    // Four descriptors sent, room for one on the receiving side.
    #[test]
    fn send_more_recv_less2() {
        assert_recv_fails_with_fd_slots(1);
    }
}