use nix::{
fcntl::{F_SETFD, FdFlag, fcntl},
libc,
};
use std::{
collections::HashMap,
fs::{canonicalize, read_dir},
os::fd::{FromRawFd, OwnedFd, RawFd},
sync::{Mutex, OnceLock},
};
use thiserror::Error;
/// Process-wide registry of inherited file descriptors, filled in once by
/// [`init_inherited_fds`]. A `None` value marks a descriptor whose ownership has already been
/// handed out by [`take_fd_ownership`].
static INHERITED_FDS: OnceLock<Mutex<HashMap<RawFd, Option<OwnedFd>>>> = OnceLock::new();
#[derive(Debug, PartialEq, Error)]
pub enum InheritedFdError {
#[error("init_inherited_fds() not called")]
NotInitialized,
#[error("Ownership of FD {0} is already taken")]
OwnershipTaken(RawFd),
#[error("FD {0} is either invalid file descriptor or not an inherited one")]
FileDescriptorNotInherited(RawFd),
}
/// Claims ownership of every file descriptor this process inherited, so that each one can later
/// be handed out exactly once through [`take_fd_ownership`].
///
/// Every descriptor listed in `/proc/self/fd` is wrapped in an [`OwnedFd`] and marked
/// `FD_CLOEXEC`, except two kinds:
/// - stdin, stdout and stderr;
/// - the directory handle used for the listing itself.
///
/// # Safety
///
/// Every descriptor found is wrapped in an [`OwnedFd`] without checking whether something else
/// already owns it. The caller must invoke this function before any other code in the process
/// has opened or taken ownership of non-stdio file descriptors, typically as the first thing in
/// `main`.
///
/// # Errors
///
/// Returns an error in these cases:
/// - the registry was already initialized;
/// - `/proc/self/fd` cannot be read;
/// - one of its entries is not a descriptor number;
/// - `FD_CLOEXEC` cannot be set.
///
/// On any error no descriptor is closed. Descriptors scanned so far stay open and unowned, though
/// some of them may already have been marked `FD_CLOEXEC`.
pub unsafe fn init_inherited_fds() -> Result<(), std::io::Error> {
    let already_initialized = || std::io::Error::other("Inherited fds were already initialized");

    // Bail out before scanning. A second scan would wrap descriptors that the first call (or a
    // caller of `take_fd_ownership`) already owns, and dropping those duplicates would close them.
    if INHERITED_FDS.get().is_some() {
        return Err(already_initialized());
    }

    let mut fds = HashMap::new();
    // SAFETY: forwarded from this function's own safety contract.
    if let Err(err) = unsafe { claim_inherited_fds(&mut fds) } {
        // These descriptors were never published, so nobody else can own them through us.
        // Leave them open instead of closing them behind the caller's back.
        release_without_closing(fds);
        return Err(err);
    }

    INHERITED_FDS.set(Mutex::new(fds)).map_err(|rejected| {
        // The registry was set concurrently after the check above. Its map owns the same
        // descriptors, so ours must be discarded without closing anything.
        release_without_closing(
            rejected
                .into_inner()
                .unwrap_or_else(std::sync::PoisonError::into_inner),
        );
        already_initialized()
    })
}

/// Scans `/proc/self/fd` and records every inherited descriptor in `fds`, marking each one
/// `FD_CLOEXEC`.
///
/// On error, descriptors already recorded stay in `fds` for the caller to dispose of. The
/// descriptor being processed when `FD_CLOEXEC` fails is released without being closed.
///
/// # Safety
///
/// Same contract as [`init_inherited_fds`].
unsafe fn claim_inherited_fds(
    fds: &mut HashMap<RawFd, Option<OwnedFd>>,
) -> Result<(), std::io::Error> {
    let fd_path = canonicalize("/proc/self/fd")?;
    for entry in read_dir(&fd_path)? {
        let entry = entry?;
        let file_name = entry.file_name();
        // Entries of /proc/self/fd are expected to be descriptor numbers. Report anything else
        // as an error instead of panicking.
        let raw_fd = file_name
            .to_str()
            .and_then(|name| name.parse::<RawFd>().ok())
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("unexpected entry {file_name:?} in {}", fd_path.display()),
                )
            })?;

        // Leave stdin/stdout/stderr alone; std's standard stream handles use these descriptors.
        if [libc::STDIN_FILENO, libc::STDOUT_FILENO, libc::STDERR_FILENO].contains(&raw_fd) {
            continue;
        }

        // The directory handle opened by `read_dir` above shows up in its own listing. It is
        // owned (and closed) by the iterator, so it must not be claimed.
        if entry.path().read_link()? == fd_path {
            continue;
        }

        // SAFETY: `raw_fd` was just listed as open in /proc/self/fd, and the caller guarantees
        // that no other code owns it yet.
        let owned_fd = unsafe { OwnedFd::from_raw_fd(raw_fd) };
        if let Err(errno) = fcntl(&owned_fd, F_SETFD(FdFlag::FD_CLOEXEC)) {
            // Dropping `owned_fd` here would close a descriptor we never managed to register.
            std::mem::forget(owned_fd);
            return Err(errno.into());
        }
        fds.insert(raw_fd, Some(owned_fd));
    }
    Ok(())
}

/// Gives up ownership of every descriptor in `fds` without closing any of them.
fn release_without_closing(fds: HashMap<RawFd, Option<OwnedFd>>) {
    fds.into_values().flatten().for_each(std::mem::forget);
}
/// Hands out ownership of the inherited file descriptor `raw_fd`.
///
/// Each descriptor recorded by [`init_inherited_fds`] can be taken exactly once. The returned
/// [`OwnedFd`] closes the descriptor when dropped.
///
/// # Errors
///
/// * [`InheritedFdError::NotInitialized`] if the registry has not been populated.
/// * [`InheritedFdError::FileDescriptorNotInherited`] if `raw_fd` is not in the registry.
/// * [`InheritedFdError::OwnershipTaken`] if `raw_fd` was already taken.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
pub fn take_fd_ownership(raw_fd: RawFd) -> Result<OwnedFd, InheritedFdError> {
    let Some(registry) = INHERITED_FDS.get() else {
        return Err(InheritedFdError::NotInitialized);
    };
    let mut guard = registry.lock().unwrap();
    match guard.get_mut(&raw_fd) {
        None => Err(InheritedFdError::FileDescriptorNotInherited(raw_fd)),
        // `take` leaves `None` behind so later requests report `OwnershipTaken`.
        Some(slot) => slot
            .take()
            .ok_or(InheritedFdError::OwnershipTaken(raw_fd)),
    }
}
/// Tests for the inherited-FD registry.
///
/// NOTE(review): `INHERITED_FDS` is a process-wide `OnceLock`, so only the first
/// `init_inherited_fds()` call in a process can succeed. Yet most tests `unwrap()` that call,
/// and `access_without_init_inherited_fds` expects it never to have happened. These tests
/// therefore appear to assume each `#[test]` runs in a fresh process. Confirm the test runner
/// does that.
#[cfg(test)]
mod test {
    use super::*;
    use nix::unistd::close;
    use std::{
        io,
        os::fd::{AsRawFd, IntoRawFd},
    };
    use tempfile::tempfile;

    /// Temporary files whose raw descriptors play the role of inherited FDs.
    ///
    /// The descriptors are closed when the fixture is dropped.
    struct Fixture {
        fds: Vec<RawFd>,
    }

    impl Fixture {
        /// Opens `num_fds` temporary files and keeps their raw descriptors.
        fn setup(num_fds: usize) -> Result<Self, io::Error> {
            let mut fds = Vec::new();
            for _ in 0..num_fds {
                fds.push(tempfile()?.into_raw_fd());
            }
            Ok(Fixture { fds })
        }

        /// Opens one more temporary file, records its descriptor and returns it.
        fn open_new_file(&mut self) -> Result<RawFd, io::Error> {
            let raw_fd = tempfile()?.into_raw_fd();
            self.fds.push(raw_fd);
            Ok(raw_fd)
        }
    }

    impl Drop for Fixture {
        fn drop(&mut self) {
            // Close errors are ignored: a test may already have closed a descriptor by dropping
            // the `OwnedFd` it took over.
            // NOTE(review): such a descriptor number may have been reused in the meantime, in
            // which case this closes an unrelated file.
            self.fds.iter().for_each(|fd| {
                let _ = close(*fd);
            });
        }
    }

    /// Returns whether `raw_fd` currently refers to an open descriptor.
    fn is_fd_opened(raw_fd: RawFd) -> bool {
        // SAFETY: F_GETFD only queries descriptor flags; an invalid `raw_fd` yields -1.
        unsafe { libc::fcntl(raw_fd, libc::F_GETFD) != -1 }
    }

    #[test]
    fn happy_case() {
        let fixture = Fixture::setup(2).unwrap();
        let f0 = fixture.fds[0];
        let f1 = fixture.fds[1];

        // SAFETY: the fixture's descriptors stand in for inherited ones and are only held as
        // raw FDs, so nothing else owns them.
        unsafe {
            init_inherited_fds().unwrap();
        }

        let f0_owned = take_fd_ownership(f0).unwrap();
        let f1_owned = take_fd_ownership(f1).unwrap();
        assert_eq!(f0, f0_owned.as_raw_fd());
        assert_eq!(f1, f1_owned.as_raw_fd());

        // Dropping the owned handles must close the underlying descriptors.
        drop(f0_owned);
        drop(f1_owned);
        assert!(!is_fd_opened(f0));
        assert!(!is_fd_opened(f1));
    }

    #[test]
    fn access_non_inherited_fd() {
        let mut fixture = Fixture::setup(2).unwrap();

        // SAFETY: the fixture's descriptors stand in for inherited ones and are only held as
        // raw FDs, so nothing else owns them.
        unsafe {
            init_inherited_fds().unwrap();
        }

        // Opened after initialization, so it must not be in the registry.
        let f = fixture.open_new_file().unwrap();
        assert_eq!(
            take_fd_ownership(f).err(),
            Some(InheritedFdError::FileDescriptorNotInherited(f))
        );
    }

    #[test]
    fn call_init_inherited_fds_multiple_times() {
        // Bind to a named variable: `let _ =` would drop the fixture (closing its files)
        // immediately, before the registry is initialized.
        let _fixture = Fixture::setup(2).unwrap();

        // SAFETY: the fixture's descriptors stand in for inherited ones and are only held as
        // raw FDs, so nothing else owns them.
        unsafe {
            init_inherited_fds().unwrap();
        }

        // SAFETY: same as above; this call is expected to be rejected.
        let res = unsafe { init_inherited_fds() };
        assert!(res.is_err());
    }

    #[test]
    fn access_without_init_inherited_fds() {
        let fixture = Fixture::setup(2).unwrap();
        let f = fixture.fds[0];
        assert_eq!(
            take_fd_ownership(f).err(),
            Some(InheritedFdError::NotInitialized)
        );
    }

    #[test]
    fn double_ownership() {
        let fixture = Fixture::setup(2).unwrap();
        let f = fixture.fds[0];

        // SAFETY: the fixture's descriptors stand in for inherited ones and are only held as
        // raw FDs, so nothing else owns them.
        unsafe {
            init_inherited_fds().unwrap();
        }

        let f_owned = take_fd_ownership(f).unwrap();
        let f_double_owned = take_fd_ownership(f);
        assert_eq!(
            f_double_owned.err(),
            Some(InheritedFdError::OwnershipTaken(f)),
        );

        drop(f_owned);
    }

    #[test]
    fn take_drop_retake() {
        let fixture = Fixture::setup(2).unwrap();
        let f = fixture.fds[0];

        // SAFETY: the fixture's descriptors stand in for inherited ones and are only held as
        // raw FDs, so nothing else owns them.
        unsafe {
            init_inherited_fds().unwrap();
        }

        // Dropping a taken descriptor does not make it available again.
        let f_owned = take_fd_ownership(f).unwrap();
        drop(f_owned);

        let f_double_owned = take_fd_ownership(f);
        assert_eq!(
            f_double_owned.err(),
            Some(InheritedFdError::OwnershipTaken(f)),
        );
    }

    #[test]
    fn cloexec() {
        let fixture = Fixture::setup(2).unwrap();
        let f = fixture.fds[0];

        // Clear FD_CLOEXEC so the test can observe initialization setting it again.
        // SAFETY: F_SETFD only changes descriptor flags of a descriptor the fixture opened.
        let res = unsafe { libc::fcntl(f.as_raw_fd(), libc::F_SETFD, 0) };
        assert_ne!(res, -1);

        // SAFETY: the fixture's descriptors stand in for inherited ones and are only held as
        // raw FDs, so nothing else owns them.
        unsafe {
            init_inherited_fds().unwrap();
        }

        // SAFETY: F_GETFD only queries descriptor flags.
        let flags = unsafe { libc::fcntl(f.as_raw_fd(), libc::F_GETFD) };
        assert_ne!(flags, -1);
        assert_eq!(flags, FdFlag::FD_CLOEXEC.bits());
    }
}