pub mod inherited;
#[cfg(feature = "tokio")]
pub mod tokio;
use nix::fcntl::{FcntlArg, FdFlag, fcntl};
use nix::unistd::dup2_raw;
use std::cmp::max;
use std::io;
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd};
use std::os::unix::io::RawFd;
use std::os::unix::process::CommandExt;
use std::process::Command;
use thiserror::Error;
/// A request to make one of the parent's file descriptors available in a child process under a
/// chosen descriptor number.
///
/// The parent side is held as an [`OwnedFd`], so the mapping owns that descriptor.
#[derive(Debug)]
pub struct FdMapping {
/// The descriptor, owned by this mapping, that is duplicated for the child.
pub parent_fd: OwnedFd,
/// The descriptor number under which `parent_fd` should appear in the child.
pub child_fd: RawFd,
}
/// Error reported when a set of mappings names the same child file descriptor more than once.
#[derive(Copy, Clone, Debug, Eq, Error, PartialEq)]
#[error("Two or more mappings for the same child FD")]
pub struct FdMappingCollision;
/// Extension trait for configuring which file descriptors a spawned command receives.
pub trait CommandFdExt {
/// Registers `mappings` to be applied to the child process.
///
/// # Errors
///
/// The implementation for [`Command`] in this file returns [`FdMappingCollision`] when two or
/// more of the mappings share the same `child_fd`.
fn fd_mappings(&mut self, mappings: Vec<FdMapping>) -> Result<&mut Self, FdMappingCollision>;
/// Registers `fds` to have their descriptor flags cleared in the child, keeping their existing
/// descriptor numbers.
///
/// The implementation for [`Command`] in this file takes ownership of the descriptors and
/// applies `F_SETFD` with an empty flag set to each of them.
fn preserved_fds(&mut self, fds: Vec<OwnedFd>) -> &mut Self;
}
impl CommandFdExt for Command {
    /// Validates `mappings` and registers a `pre_exec` hook that applies them.
    ///
    /// The mappings (and the descriptors they own) are moved into the hook, so they live for as
    /// long as the `Command` keeps the hook.
    ///
    /// # Errors
    ///
    /// Returns [`FdMappingCollision`] if two or more mappings share a `child_fd`; in that case no
    /// hook is registered.
    fn fd_mappings(
        &mut self,
        mut mappings: Vec<FdMapping>,
    ) -> Result<&mut Self, FdMappingCollision> {
        // Reject collisions up front; the returned list of child FDs is handed to the hook.
        let child_fds = validate_child_fds(&mappings)?;

        let apply_mappings = move || map_fds(&mut mappings, &child_fds);

        // SAFETY: the hook is run by `pre_exec`; `map_fds` operates on the vectors captured
        // above and, as written in this file, builds no new collections — it only issues
        // `fcntl`/`dup2` calls and replaces descriptors in place.
        unsafe {
            self.pre_exec(apply_mappings);
        }

        Ok(self)
    }

    /// Registers a `pre_exec` hook that clears the descriptor flags of every FD in `fds`.
    ///
    /// Ownership of the descriptors moves into the hook.
    fn preserved_fds(&mut self, fds: Vec<OwnedFd>) -> &mut Self {
        let keep_open = move || preserve_fds(&fds);

        // SAFETY: the hook is run by `pre_exec`; `preserve_fds` only walks the captured vector
        // and issues one `fcntl` call per descriptor.
        unsafe {
            self.pre_exec(keep_open);
        }

        self
    }
}
/// Collects the child FD of every mapping, sorted in ascending order.
///
/// # Errors
///
/// Returns [`FdMappingCollision`] if any child FD appears in more than one mapping.
fn validate_child_fds(mappings: &[FdMapping]) -> Result<Vec<RawFd>, FdMappingCollision> {
    let mut sorted_fds: Vec<RawFd> = Vec::with_capacity(mappings.len());
    sorted_fds.extend(mappings.iter().map(|entry| entry.child_fd));
    sorted_fds.sort_unstable();

    // Once sorted, a repeated child FD necessarily shows up as two equal neighbours.
    let has_duplicate = sorted_fds.windows(2).any(|pair| pair[0] == pair[1]);
    if has_duplicate {
        Err(FdMappingCollision)
    } else {
        Ok(sorted_fds)
    }
}
/// Rearranges the calling process's file descriptors so that every mapping's `parent_fd` is
/// available at its `child_fd`.
///
/// In this file it is registered as a `pre_exec` hook by `fd_mappings`, with `child_fds` being
/// the list of all `child_fd` values in `mappings`.
///
/// The work happens in two passes:
///
/// 1. Any parent FD whose number is also wanted as *another* mapping's child FD is first
///    duplicated to a temporary FD above every FD mentioned in `mappings`, so that the second
///    pass cannot overwrite it before it has been copied.
/// 2. Each mapping is then applied: a mapping onto its own number just has its descriptor flags
///    cleared, anything else is duplicated onto `child_fd`.
///
/// # Errors
///
/// Returns the error of the first `fcntl` or `dup2` call that fails, converted to [`io::Error`].
fn map_fds(mappings: &mut [FdMapping], child_fds: &[RawFd]) -> io::Result<()> {
    // With no mappings there is nothing to do and no highest FD to compute.
    let Some(highest_fd) = mappings
        .iter()
        .map(|mapping| max(mapping.parent_fd.as_raw_fd(), mapping.child_fd))
        .max()
    else {
        return Ok(());
    };

    // Lower bound for temporary FDs: above every parent and child FD in the mappings.
    // `saturating_add` avoids an arithmetic-overflow panic (debug builds) inside the hook when a
    // mapping uses `RawFd::MAX`; in that case the system calls below report the error instead.
    let first_safe_fd = highest_fd.saturating_add(1);

    // Pass 1: move conflicting parent FDs out of the way.
    for mapping in mappings.iter_mut() {
        let parent_raw = mapping.parent_fd.as_raw_fd();
        if parent_raw != mapping.child_fd && child_fds.contains(&parent_raw) {
            // `F_DUPFD_CLOEXEC` is asked for a descriptor numbered at least `first_safe_fd`.
            let temp_fd = fcntl(&mapping.parent_fd, FcntlArg::F_DUPFD_CLOEXEC(first_safe_fd))?;
            // SAFETY: `temp_fd` was just returned by the successful `fcntl` call above and is
            // not wrapped by any other owner. Assigning drops the previous `OwnedFd`.
            mapping.parent_fd = unsafe { OwnedFd::from_raw_fd(temp_fd) };
        }
    }

    // Pass 2: put every descriptor at its final number.
    for mapping in mappings {
        if mapping.child_fd == mapping.parent_fd.as_raw_fd() {
            // Already at the right number: only reset its descriptor flags to the empty set.
            fcntl(&mapping.parent_fd, FcntlArg::F_SETFD(FdFlag::empty()))?;
        } else {
            // SAFETY: the value returned by `dup2_raw` is immediately turned back into a raw FD
            // with `into_raw_fd`, so nothing here ends up owning (and later closing) `child_fd`.
            unsafe {
                let _ = dup2_raw(&mapping.parent_fd, mapping.child_fd)?.into_raw_fd();
            }
        }
    }

    Ok(())
}
/// Resets the descriptor flags of every FD in `fds` to the empty set via `F_SETFD`.
///
/// In this file it is registered as a `pre_exec` hook by `preserved_fds`.
///
/// # Errors
///
/// Stops at, and returns, the first `fcntl` failure, converted to [`io::Error`].
fn preserve_fds(fds: &[OwnedFd]) -> io::Result<()> {
    fds.iter().try_for_each(|descriptor| -> io::Result<()> {
        fcntl(descriptor, FcntlArg::F_SETFD(FdFlag::empty()))?;
        Ok(())
    })
}
// Tests for the FD mapping and preservation hooks.
//
// Most tests spawn `ls /proc/self/fd` and compare the listed descriptor numbers with what the
// mappings should have produced, so they need a `/proc` filesystem and the files under
// `testdata/`. Several tests depend on the exact descriptor numbers the test process hands out,
// which is why every test starts by calling `setup()`.
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::close;
use std::collections::HashSet;
use std::fs::{File, read_dir};
use std::os::unix::io::AsRawFd;
use std::process::Output;
use std::str;
use std::sync::Once;
// Guards `close_excess_fds` so it runs at most once per test process.
static SETUP: Once = Once::new();
/// Two mappings that both target child FD 4 must be rejected.
#[test]
fn conflicting_mappings() {
setup();
let mut command = Command::new("ls");
let file1 = File::open("testdata/file1.txt").unwrap();
let file2 = File::open("testdata/file2.txt").unwrap();
assert!(
command
.fd_mappings(vec![
FdMapping {
child_fd: 4,
parent_fd: file1.into(),
},
FdMapping {
child_fd: 4,
parent_fd: file2.into(),
},
])
.is_err()
);
}
/// An empty mapping list is accepted and the child sees exactly FDs 0-3.
#[test]
fn no_mappings() {
setup();
let mut command = Command::new("ls");
command.arg("/proc/self/fd");
assert!(command.fd_mappings(vec![]).is_ok());
let output = command.output().unwrap();
// FD 3 is presumably the handle `ls` itself opens on the directory — TODO confirm.
expect_fds(&output, &[0, 1, 2, 3], 0);
}
/// Preserving no descriptors leaves the child with exactly FDs 0-3.
#[test]
fn none_preserved() {
setup();
let mut command = Command::new("ls");
command.arg("/proc/self/fd");
command.preserved_fds(vec![]);
let output = command.output().unwrap();
expect_fds(&output, &[0, 1, 2, 3], 0);
}
/// A single file mapped to child FD 5 shows up as FD 5 in the child, next to FDs 0-3.
#[test]
fn one_mapping() {
setup();
let mut command = Command::new("ls");
command.arg("/proc/self/fd");
let file = File::open("testdata/file1.txt").unwrap();
assert!(
command
.fd_mappings(vec![FdMapping {
parent_fd: file.into(),
child_fd: 5,
},])
.is_ok()
);
let output = command.output().unwrap();
expect_fds(&output, &[0, 1, 2, 3, 5], 0);
}
/// A preserved descriptor keeps its parent-side number in the child.
#[test]
#[ignore = "flaky on GitHub"]
fn one_preserved() {
setup();
let mut command = Command::new("ls");
command.arg("/proc/self/fd");
let file = File::open("testdata/file1.txt").unwrap();
let file_fd: OwnedFd = file.into();
let raw_file_fd = file_fd.as_raw_fd();
// The expectation below lists FD 3 separately, so the file must not have landed on 0-3.
assert!(raw_file_fd > 3);
command.preserved_fds(vec![file_fd]);
let output = command.output().unwrap();
expect_fds(&output, &[0, 1, 2, 3, raw_file_fd], 0);
}
/// Two descriptors mapped onto each other's numbers: each parent FD is also the other
/// mapping's child FD.
#[test]
fn swap_mappings() {
setup();
let mut command = Command::new("ls");
command.arg("/proc/self/fd");
let file1 = File::open("testdata/file1.txt").unwrap();
let file2 = File::open("testdata/file2.txt").unwrap();
let fd1: OwnedFd = file1.into();
let fd2: OwnedFd = file2.into();
let fd1_raw = fd1.as_raw_fd();
let fd2_raw = fd2.as_raw_fd();
assert!(
command
.fd_mappings(vec![
FdMapping {
parent_fd: fd1,
child_fd: fd2_raw,
},
FdMapping {
parent_fd: fd2,
child_fd: fd1_raw,
},
])
.is_ok(),
);
let output = command.output().unwrap();
// Exactly one further, unspecified FD is tolerated — presumably the directory handle of
// `ls`, whose number is not pinned here; TODO confirm.
expect_fds(&output, &[0, 1, 2, fd1_raw, fd2_raw], 1);
}
/// A descriptor mapped onto its own number stays available under that number.
#[test]
fn one_to_one_mapping() {
setup();
let mut command = Command::new("ls");
command.arg("/proc/self/fd");
let file1 = File::open("testdata/file1.txt").unwrap();
// NOTE(review): `file2` is only opened here and dropped at the end of the test; its purpose
// is not evident from this file — confirm before removing.
let file2 = File::open("testdata/file2.txt").unwrap();
let fd1: OwnedFd = file1.into();
let fd1_raw = fd1.as_raw_fd();
assert!(
command
.fd_mappings(vec![FdMapping {
parent_fd: fd1,
child_fd: fd1_raw,
}])
.is_ok()
);
let output = command.output().unwrap();
expect_fds(&output, &[0, 1, 2, fd1_raw], 1);
drop(file2);
}
/// Mapping a file to child FD 0 makes `cat` print the file's contents.
#[test]
fn map_stdin() {
setup();
let mut command = Command::new("cat");
let file = File::open("testdata/file1.txt").unwrap();
assert!(
command
.fd_mappings(vec![FdMapping {
parent_fd: file.into(),
child_fd: 0,
},])
.is_ok()
);
let output = command.output().unwrap();
assert!(output.status.success());
// Requires `testdata/file1.txt` to contain exactly these bytes.
assert_eq!(output.stdout, b"test 1");
}
/// Splits the captured output into one string per line and returns them as a set.
fn parse_ls_output(output: &[u8]) -> HashSet<String> {
str::from_utf8(output)
.unwrap()
.split_terminator("\n")
.map(str::to_owned)
.collect()
}
/// Asserts that the command succeeded and that its listed FDs match `expected_fds`.
///
/// With `extra == 0` the listed set must equal `expected_fds`; otherwise every expected FD must
/// be listed and exactly `extra` additional entries must be present.
fn expect_fds(output: &Output, expected_fds: &[RawFd], extra: usize) {
assert!(output.status.success());
let expected_fds: HashSet<String> = expected_fds.iter().map(RawFd::to_string).collect();
let fds = parse_ls_output(&output.stdout);
if extra == 0 {
assert_eq!(fds, expected_fds);
} else {
assert!(expected_fds.is_subset(&fds));
assert_eq!(fds.len(), expected_fds.len() + extra);
}
}
/// Runs `close_excess_fds` the first time any test calls it; later calls do nothing.
fn setup() {
SETUP.call_once(close_excess_fds);
}
/// Closes every descriptor numbered above 3 that is listed in `/proc/self/fd`.
// NOTE(review): descriptors are closed while the directory handle is still being iterated. If
// that handle's own descriptor is numbered above 3 it would be closed as well — confirm that
// this cannot happen in the environments the tests run in.
fn close_excess_fds() {
let dir = read_dir("/proc/self/fd").unwrap();
for entry in dir {
let entry = entry.unwrap();
let fd: RawFd = entry.file_name().to_str().unwrap().parse().unwrap();
if fd > 3 {
close(fd).unwrap();
}
}
}
}