#![allow(unsafe_code)]
use std::os::fd::RawFd;
use std::os::unix::process::CommandExt;
use std::process::Command;
use crate::role::{ENV_LISTEN_FDNAMES, ENV_LISTEN_FDS};
/// Prepares `cmd` so the spawned child inherits `listeners` as consecutive
/// descriptors starting at fd 3, announced through environment variables.
///
/// * `ENV_LISTEN_FDS` is set to `listeners.len()`. The optional `extra_fd` is
///   *not* counted.
/// * `ENV_LISTEN_FDNAMES` is set to the listener names joined with `:`, in the
///   same order as the descriptors. Names are not escaped, so a name that
///   itself contains `:` would shift every later name; callers must avoid that.
/// * `listeners[i].1` is placed on fd `3 + i`. If `extra_fd` is given, it
///   lands immediately after them, on fd `3 + listeners.len()`.
///
/// The descriptor shuffle is deferred to a `pre_exec` hook (see
/// `arrange_inherited_fds_on_spawn`), so the calling process's own descriptor
/// table is not modified.
///
/// NOTE(review): this resembles systemd's socket-activation protocol, which
/// also expects a `LISTEN_PID` variable. None is set here; confirm whether
/// the child validates one.
pub fn pass_listener_fds_on_spawn(
    cmd: &mut Command,
    listeners: &[(String, RawFd)],
    extra_fd: Option<RawFd>,
) {
    // Borrow the names instead of cloning each String just to join them.
    let names = listeners
        .iter()
        .map(|(name, _)| name.as_str())
        .collect::<Vec<&str>>()
        .join(":");
    cmd.env(ENV_LISTEN_FDS, listeners.len().to_string());
    cmd.env(ENV_LISTEN_FDNAMES, names);

    // Listener fds first, in order, then the optional extra fd after them.
    let mut fds: Vec<RawFd> = Vec::with_capacity(listeners.len() + 1);
    fds.extend(listeners.iter().map(|(_, fd)| *fd));
    fds.extend(extra_fd);
    arrange_inherited_fds_on_spawn(cmd, fds);
}
/// Registers a `pre_exec` hook on `cmd` that, in the child, moves
/// `sources[i]` onto descriptor `3 + i` with close-on-exec cleared, right
/// before `exec`.
///
/// The vector is moved into the hook and reused there as scratch space. This
/// means the hook never needs to allocate after `fork`.
pub fn arrange_inherited_fds_on_spawn(cmd: &mut Command, mut sources: Vec<RawFd>) {
    let hook = move || install_inherited_fds(&mut sources);
    // SAFETY: `pre_exec` runs the closure in the forked child, where only
    // async-signal-safe work is sound. `install_inherited_fds` performs no heap
    // allocation and only calls fcntl/dup2/close on integers, writing into the
    // already-allocated `sources` buffer.
    unsafe {
        cmd.pre_exec(hook);
    }
}
fn install_inherited_fds(working: &mut [RawFd]) -> std::io::Result<()> {
let n = working.len();
if n == 0 {
return Ok(());
}
let staging_min: RawFd = 3 + (n as RawFd);
for slot in working.iter_mut() {
let new_fd = unsafe { libc::fcntl(*slot, libc::F_DUPFD, staging_min) };
if new_fd == -1 {
return Err(std::io::Error::last_os_error());
}
*slot = new_fd;
}
for (i, staged_fd) in working.iter().enumerate() {
let dst = 3 + i as RawFd;
if unsafe { libc::dup2(*staged_fd, dst) } == -1 {
return Err(std::io::Error::last_os_error());
}
}
for staged_fd in working.iter() {
unsafe { libc::close(*staged_fd) };
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_sources_is_noop() {
        let mut empty: [RawFd; 0] = [];
        install_inherited_fds(&mut empty).unwrap();
    }

    #[test]
    fn invalid_source_is_reported_before_touching_targets() {
        // A single invalid source fails in the staging phase (F_DUPFD -> EBADF),
        // before any dup2 onto fds 3.., so this is safe to run in-process.
        let mut bogus: [RawFd; 1] = [-1];
        let err = install_inherited_fds(&mut bogus).unwrap_err();
        assert_eq!(err.raw_os_error(), Some(libc::EBADF));
    }

    #[test]
    fn install_in_forked_child_settles_targets_and_clears_cloexec() {
        use nix::sys::socket::{AddressFamily, SockFlag, SockType, socketpair};
        use nix::sys::wait::{WaitStatus, waitpid};
        use nix::unistd::{ForkResult, fork};
        use std::os::fd::AsRawFd;

        // (st_dev, st_ino) of the file behind `fd`, or None if fstat fails.
        // It only calls fstat on a stack buffer, so it is usable after fork.
        let identity = |fd: RawFd| {
            // SAFETY: `libc::stat` is plain old data, so all-zero is a valid
            // value; fstat only writes into the struct we pass.
            let mut st: libc::stat = unsafe { std::mem::zeroed() };
            if unsafe { libc::fstat(fd, &mut st) } == 0 {
                Some((st.st_dev, st.st_ino))
            } else {
                None
            }
        };

        let mk = || {
            socketpair(
                AddressFamily::Unix,
                SockType::Stream,
                None,
                SockFlag::SOCK_CLOEXEC,
            )
            .unwrap()
        };
        // Keep the OwnedFds (sources and peers) alive for the whole test; they
        // are closed when `pairs` drops in the parent.
        let pairs = [mk(), mk(), mk()];
        let sources: [RawFd; 3] = [
            pairs[0].0.as_raw_fd(),
            pairs[1].0.as_raw_fd(),
            pairs[2].0.as_raw_fd(),
        ];
        // Record which open file each source refers to *before* forking, so the
        // child can verify the right socket landed on each target slot.
        let expected = [
            identity(sources[0]),
            identity(sources[1]),
            identity(sources[2]),
        ];
        assert!(expected.iter().all(Option::is_some), "fstat failed on a source");

        // SAFETY: the test harness is multi-threaded, so the child must stay
        // async-signal-safe. It only uses stack data and calls fcntl, dup2,
        // close, fstat and _exit. In particular it never allocates.
        match unsafe { fork() }.expect("fork") {
            ForkResult::Child => {
                // Stack copy of the sources; no heap allocation after fork.
                let mut working = sources;
                let code = if install_inherited_fds(&mut working).is_err() {
                    1
                } else {
                    let mut code = 0;
                    for (i, want) in expected.iter().enumerate() {
                        let dst = 3 + i as RawFd;
                        // SAFETY: F_GETFD only reads the descriptor flags.
                        let flags = unsafe { libc::fcntl(dst, libc::F_GETFD) };
                        if flags < 0 || (flags & libc::FD_CLOEXEC) != 0 {
                            code = 2;
                            break;
                        }
                        if identity(dst) != *want {
                            code = 3;
                            break;
                        }
                    }
                    code
                };
                // SAFETY: _exit terminates the child immediately, without running
                // destructors or the parent's atexit handlers.
                unsafe { libc::_exit(code) };
            }
            ForkResult::Parent { child } => {
                let status = waitpid(child, None).unwrap();
                assert!(
                    matches!(status, WaitStatus::Exited(_, 0)),
                    "child reported failure: {status:?}"
                );
            }
        }
    }
}