use cap_std::fs::Dir;
use cap_std::io_lifetimes;
use cap_tempfile::cap_std;
use io_lifetimes::OwnedFd;
use rustix::fd::{AsFd, FromRawFd, IntoRawFd};
use rustix::io::FdFlags;
use std::collections::BTreeSet;
use std::ffi::CString;
use std::os::fd::AsRawFd;
use std::os::unix::process::CommandExt;
use std::sync::Arc;
/// First descriptor number of the systemd socket-activation protocol
/// (`SD_LISTEN_FDS_START` in sd-daemon); numbers 0-2 are stdin/stdout/stderr.
/// `CmdFds::new_systemd_fds` numbers its descriptors from here, and
/// `CmdFds::take_fd` uses it as the first number when nothing is assigned yet.
const SD_LISTEN_FDS_START: i32 = 3;
/// A validated name for one entry of `LISTEN_FDNAMES`.
///
/// A valid name is at most 255 bytes long. It consists only of printable ASCII
/// (space through `~`) and never contains `':'`, which separates the entries of
/// `LISTEN_FDNAMES`. Because validation happens in a `const fn`, a name built
/// in a `const` context is checked at compile time.
#[derive(Debug, Clone, Copy)]
pub struct SystemdFdName<'a>(&'a str);

impl<'a> SystemdFdName<'a> {
    /// Validate `name` and wrap it.
    ///
    /// # Panics
    ///
    /// If `name` is longer than 255 bytes, or contains any byte that is not
    /// printable ASCII, or contains `':'`.
    pub const fn new(name: &'a str) -> Self {
        assert!(
            name.len() <= 255,
            "systemd fd name must be at most 255 characters"
        );
        let raw = name.as_bytes();
        // `for` loops are not allowed in `const fn`, hence the manual index walk.
        let mut idx = 0;
        while idx < raw.len() {
            assert!(
                Self::is_allowed_byte(raw[idx]),
                "systemd fd name must only contain printable ASCII characters except ':'"
            );
            idx += 1;
        }
        Self(name)
    }

    /// Printable ASCII (0x20 ..= 0x7e) other than the `':'` separator.
    const fn is_allowed_byte(byte: u8) -> bool {
        matches!(byte, b' '..=b'~') && byte != b':'
    }

    /// The validated name.
    pub fn as_str(&self) -> &'a str {
        self.0
    }
}
/// A set of file descriptors to install in a child process at chosen fd
/// numbers; consumed by `CapStdExtCommandExt::take_fds`.
#[derive(Debug)]
pub struct CmdFds {
    /// Every target fd number assigned so far. Used to reject duplicate
    /// targets and to pick the next number in `take_fd`.
    taken: BTreeSet<i32>,
    /// `(target fd number in the child, descriptor to install there)`.
    fds: Vec<(i32, Arc<OwnedFd>)>,
    /// Pre-rendered `(LISTEN_FDS, LISTEN_FDNAMES)` values. Set by
    /// `new_systemd_fds`; `None` for sets built with `new`.
    systemd_env: Option<(CString, CString)>,
}
impl Default for CmdFds {
fn default() -> Self {
Self::new()
}
}
impl CmdFds {
    /// Create an empty set: no descriptors mapped and no systemd environment.
    pub fn new() -> Self {
        Self {
            taken: BTreeSet::new(),
            fds: Vec::new(),
            systemd_env: None,
        }
    }

    /// Create a set following the systemd socket-activation convention.
    ///
    /// The descriptors are mapped, in iteration order, to consecutive numbers
    /// starting at 3 (`SD_LISTEN_FDS_START`). When the set is passed to
    /// `CapStdExtCommandExt::take_fds`, the child also gets `LISTEN_PID`,
    /// `LISTEN_FDS` (the number of descriptors given here) and `LISTEN_FDNAMES`
    /// (their names joined with `':'`).
    ///
    /// More descriptors may be added afterwards with [`Self::take_fd`] or
    /// [`Self::take_fd_n`]; those are not counted in `LISTEN_FDS`.
    ///
    /// # Panics
    ///
    /// If there are too many descriptors to number within `i32`.
    pub fn new_systemd_fds<'a>(
        fds: impl IntoIterator<Item = (Arc<OwnedFd>, SystemdFdName<'a>)>,
    ) -> Self {
        let mut this = Self::new();
        this.register_systemd_fds(fds);
        this
    }

    /// The number `take_fd` assigns next. It is one past the highest number
    /// assigned so far, but never below `SD_LISTEN_FDS_START`.
    ///
    /// The lower bound matters when a caller has explicitly claimed 0, 1, 2
    /// (or a negative number) via `take_fd_n`. Without it, `take_fd` would
    /// silently hand out the child's stdout or stderr.
    fn next_fd(&self) -> i32 {
        self.taken.last().map_or(SD_LISTEN_FDS_START, |&highest| {
            highest
                .checked_add(1)
                .expect("fd number overflow")
                .max(SD_LISTEN_FDS_START)
        })
    }

    /// Record `n` as assigned.
    ///
    /// # Panics
    ///
    /// If `n` was already assigned.
    fn insert_fd(&mut self, n: i32) {
        let inserted = self.taken.insert(n);
        assert!(inserted, "fd {n} is already assigned");
    }

    /// Map `fd` to a fresh number in the child and return that number.
    ///
    /// The number is one past the highest number assigned so far, or 3 when
    /// nothing is assigned yet. It is never below 3, so the child's
    /// stdin/stdout/stderr are never assigned implicitly.
    ///
    /// # Panics
    ///
    /// If the highest assigned number is `i32::MAX`.
    pub fn take_fd(&mut self, fd: Arc<OwnedFd>) -> i32 {
        let n = self.next_fd();
        self.insert_fd(n);
        self.fds.push((n, fd));
        n
    }

    /// Map `fd` to exactly `target` in the child.
    ///
    /// # Panics
    ///
    /// If `target` has already been assigned.
    pub fn take_fd_n(&mut self, fd: Arc<OwnedFd>, target: i32) -> &mut Self {
        self.insert_fd(target);
        self.fds.push((target, fd));
        self
    }

    /// Assign `fds` to 3, 4, 5, … in order and pre-render the `LISTEN_FDS` and
    /// `LISTEN_FDNAMES` values. Only called from `new_systemd_fds`, i.e. on a
    /// freshly created set.
    fn register_systemd_fds<'a>(
        &mut self,
        fds: impl IntoIterator<Item = (Arc<OwnedFd>, SystemdFdName<'a>)>,
    ) {
        let mut n_fds: i32 = 0;
        let mut names = Vec::new();
        for (fd, name) in fds {
            let target = SD_LISTEN_FDS_START
                .checked_add(n_fds)
                .expect("too many fds");
            self.insert_fd(target);
            self.fds.push((target, fd));
            names.push(name.as_str());
            n_fds = n_fds.checked_add(1).expect("too many fds");
        }
        let fd_count =
            CString::new(n_fds.to_string()).expect("a decimal integer contains no NUL byte");
        // SystemdFdName only admits bytes 0x20..=0x7e, so no NUL can appear here.
        let fd_names = CString::new(names.join(":")).expect("fd names contain no NUL byte");
        self.systemd_env = Some((fd_count, fd_names));
    }
}
/// Extensions for [`std::process::Command`] that hand file descriptors and a
/// working directory to the spawned child. Each method registers a hook that
/// runs in the child between fork and exec.
pub trait CapStdExtCommandExt {
    /// Make `fd` available in the child as descriptor number `target`.
    ///
    /// Every call registers its own independent hook. If one call's `target`
    /// equals the descriptor number passed to a later call, the later source is
    /// overwritten before it is duplicated. `CmdFds` with `take_fds` copies all
    /// sources out of the way first and so avoids this.
    #[deprecated = "Use CmdFds with take_fds() instead"]
    fn take_fd_n(&mut self, fd: Arc<OwnedFd>, target: i32) -> &mut Self;
    /// Install every mapping in `fds` in the child, leaving the installed
    /// descriptors inheritable across exec.
    ///
    /// If `fds` was built with `CmdFds::new_systemd_fds`, this also sets
    /// `LISTEN_PID`, `LISTEN_FDS` and `LISTEN_FDNAMES` in the child's environment.
    fn take_fds(&mut self, fds: CmdFds) -> &mut Self;
    /// Make `dir` the child's working directory (via `fchdir` in the child).
    fn cwd_dir(&mut self, dir: Dir) -> &mut Self;
    /// Have the kernel send `SIGTERM` to the child when its parent dies, via
    /// the parent-death-signal setting. Per prctl(2) `PR_SET_PDEATHSIG`, this
    /// fires when the spawning *thread* exits.
    #[cfg(any(target_os = "linux", target_os = "android"))]
    fn lifecycle_bind_to_parent_thread(&mut self) -> &mut Self;
}
/// Set environment variable `key` to `val` in the current process with
/// `setenv(3)`, overwriting any existing value.
///
/// # Errors
///
/// Returns the `errno`-derived error if `setenv` reports failure.
///
/// # Safety
///
/// `key` and `val` must point to valid NUL-terminated C strings. `setenv`
/// mutates the process-global environment, so nothing else may read or write
/// the environment concurrently.
#[allow(unsafe_code)]
unsafe fn check_setenv(
    key: *const std::ffi::c_char,
    val: *const std::ffi::c_char,
) -> std::io::Result<()> {
    // setenv(3) returns 0 on success and -1 (with errno set) on failure.
    match unsafe { libc::setenv(key, val, 1) } {
        0 => Ok(()),
        _ => Err(std::io::Error::last_os_error()),
    }
}
#[allow(unsafe_code)]
#[allow(deprecated)]
impl CapStdExtCommandExt for std::process::Command {
    fn take_fd_n(&mut self, fd: Arc<OwnedFd>, target: i32) -> &mut Self {
        // SAFETY: the hook only issues fcntl(2)/dup2(2), which are
        // async-signal-safe and therefore usable between fork and exec.
        unsafe {
            self.pre_exec(move || {
                // Temporarily view `target` as an OwnedFd because rustix's dup2
                // wants one; ownership is given up again via into_raw_fd below.
                let mut target = OwnedFd::from_raw_fd(target);
                if target.as_raw_fd() == fd.as_raw_fd() {
                    // Already at the right number: only clear CLOEXEC so the
                    // descriptor survives exec.
                    let fl = rustix::io::fcntl_getfd(&target)?;
                    rustix::io::fcntl_setfd(&mut target, fl.difference(FdFlags::CLOEXEC))?;
                } else {
                    // The duplicate created by dup2 has CLOEXEC cleared.
                    rustix::io::dup2(&*fd, &mut target)?;
                }
                let _ = target.into_raw_fd();
                Ok(())
            });
        }
        self
    }

    fn take_fds(&mut self, fds: CmdFds) -> &mut Self {
        // Anything that can panic or allocate is done here in the parent. After
        // fork() in a multithreaded process only async-signal-safe operations
        // are guaranteed to work in the child, and malloc is not one of them.
        //
        // Temporary copies are placed at or above this number, so they can never
        // collide with (and be clobbered by) any target.
        let safe_min = fds
            .fds
            .iter()
            .map(|(t, _)| *t)
            .max()
            .unwrap_or(0)
            .checked_add(1)
            .expect("fd number overflow");
        // Scratch space for the temporary copies. It is preallocated so that
        // `push` in the child never needs to reallocate.
        let mut copies: Vec<(i32, OwnedFd)> = Vec::with_capacity(fds.fds.len());
        // SAFETY: apart from setenv (see the note below), the hook only performs
        // fcntl/dup2/close/getpid and writes into preallocated storage.
        unsafe {
            self.pre_exec(move || {
                // Pass 1: move every source out of the target range (as CLOEXEC
                // copies), so that installing one target cannot overwrite a
                // source that a later mapping still needs.
                for (target, fd) in &fds.fds {
                    let copy = rustix::io::fcntl_dupfd_cloexec(fd, safe_min)?;
                    copies.push((*target, copy));
                }
                // Pass 2: install each copy at its target. dup2 clears CLOEXEC
                // on the new descriptor; the copy is closed when it is dropped.
                for (target, copy) in copies.drain(..) {
                    while libc::dup2(copy.as_raw_fd(), target) < 0 {
                        let err = std::io::Error::last_os_error();
                        if err.kind() != std::io::ErrorKind::Interrupted {
                            return Err(err);
                        }
                    }
                }
                if let Some((fd_count, fd_names)) = &fds.systemd_env {
                    // LISTEN_PID has to be the child's own pid, which only exists
                    // after fork, so the variables are set here with setenv(3).
                    // NOTE(review): setenv may allocate (it is not
                    // async-signal-safe). Also, when Command::env/env_remove/
                    // env_clear have been used, std appears to install its
                    // captured environment *after* pre_exec hooks run, which
                    // would discard these variables. Confirm against the std
                    // version in use.
                    let pid = rustix::process::getpid();
                    let pid_dec = rustix::path::DecInt::new(pid.as_raw_nonzero().get());
                    check_setenv(c"LISTEN_PID".as_ptr(), pid_dec.as_c_str().as_ptr())?;
                    check_setenv(c"LISTEN_FDS".as_ptr(), fd_count.as_ptr())?;
                    check_setenv(c"LISTEN_FDNAMES".as_ptr(), fd_names.as_ptr())?;
                }
                Ok(())
            });
        }
        self
    }

    fn cwd_dir(&mut self, dir: Dir) -> &mut Self {
        // SAFETY: fchdir(2) is async-signal-safe.
        //
        // Hooks run in registration order. If a hook registered earlier (e.g.
        // from `take_fds`) installs a descriptor at the same number as `dir`,
        // fchdir will act on that descriptor instead.
        unsafe {
            self.pre_exec(move || {
                rustix::process::fchdir(dir.as_fd())?;
                Ok(())
            });
        }
        self
    }

    #[cfg(any(target_os = "linux", target_os = "android"))]
    fn lifecycle_bind_to_parent_thread(&mut self) -> &mut Self {
        // SAFETY: setting the parent-death signal is a single prctl(2) call
        // with no allocation.
        unsafe {
            self.pre_exec(|| {
                rustix::process::set_parent_process_death_signal(Some(
                    rustix::process::Signal::TERM,
                ))
                .map_err(Into::into)
            });
        }
        self
    }
}
#[cfg(all(test, any(target_os = "android", target_os = "linux")))]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// A temporary directory plus an owned duplicate of its descriptor that can
    /// be handed to a child. The `TempDir` is returned so it outlives the child.
    fn tempdir_fd() -> anyhow::Result<(cap_tempfile::TempDir, Arc<OwnedFd>)> {
        let tempd = cap_tempfile::TempDir::new(cap_std::ambient_authority())?;
        let fd = Arc::new(tempd.as_fd().try_clone_to_owned()?);
        Ok((tempd, fd))
    }

    #[allow(deprecated)]
    #[test]
    fn test_take_fdn() -> anyhow::Result<()> {
        // i == 0 maps the descriptor onto its own number (only CLOEXEC needs
        // clearing); i == 1 maps it onto a different number (the dup2 path).
        for i in 0..=1 {
            let (_tempd, tempd_fd) = tempdir_fd()?;
            let n = tempd_fd.as_raw_fd() + i;
            // `readlink` succeeds only if fd `n` is open in the child.
            let st = std::process::Command::new("/usr/bin/env")
                .arg("readlink")
                .arg(format!("/proc/self/fd/{n}"))
                .take_fd_n(tempd_fd, n)
                .status()?;
            assert!(st.success());
        }
        Ok(())
    }

    #[test]
    fn test_take_fds() -> anyhow::Result<()> {
        let (_tempd, tempd_fd) = tempdir_fd()?;
        let mut fds = CmdFds::new();
        let n = fds.take_fd(tempd_fd);
        assert_eq!(n, SD_LISTEN_FDS_START);
        let st = std::process::Command::new("/usr/bin/env")
            .arg("readlink")
            .arg(format!("/proc/self/fd/{n}"))
            .take_fds(fds)
            .status()?;
        assert!(st.success());
        Ok(())
    }

    #[test]
    fn test_systemd_fds_env() -> anyhow::Result<()> {
        let (_tempd, tempd_fd) = tempdir_fd()?;
        let fds = CmdFds::new_systemd_fds([(tempd_fd, SystemdFdName::new("tmpdir"))]);
        // `env` with no arguments prints the environment it was exec'd with.
        let out = std::process::Command::new("/usr/bin/env")
            .take_fds(fds)
            .output()?;
        assert!(out.status.success());
        let env = String::from_utf8(out.stdout)?;
        assert!(env.lines().any(|l| l == "LISTEN_FDS=1"));
        assert!(env.lines().any(|l| l == "LISTEN_FDNAMES=tmpdir"));
        assert!(env.lines().any(|l| l.starts_with("LISTEN_PID=")));
        Ok(())
    }

    #[test]
    #[should_panic(expected = "already assigned")]
    fn test_duplicate_target_panics() {
        let fd: Arc<OwnedFd> = Arc::new(std::fs::File::open("/dev/null").unwrap().into());
        let mut fds = CmdFds::new();
        fds.take_fd_n(Arc::clone(&fd), 5);
        fds.take_fd_n(fd, 5);
    }
}