use std::{
collections::BTreeMap,
ffi::{CStr, CString, OsStr, OsString},
io::{Error, ErrorKind, Result},
mem::MaybeUninit,
os::unix::prelude::{OsStrExt, OsStringExt},
path::Path,
ptr::null_mut,
};
use libc::{pid_t, sigemptyset, signal};
use nix::{
sys::memfd::{memfd_create, MemFdCreateFlag},
unistd::{close, fexecve, write},
};
use crate::{
anon_pipe::anon_pipe,
child::Child,
command_env::CommandEnv,
cvt::{cvt, cvt_nz, cvt_r},
output::Output,
process::{ExitStatus, Process},
stdio::{ChildPipes, Stdio, StdioPipes},
};
/// Builder that runs an executable image held in memory (`code`) as a new
/// process, exposing a `std::process::Command`-like API (`arg`, `env`, `cwd`,
/// `stdin`/`stdout`/`stderr`, `spawn`, `output`, `status`, `exec`).
#[derive(Debug)]
pub struct MemFdExecutable<'a> {
    /// Raw bytes of the program image; written into a memfd right before `fexecve`.
    code: &'a [u8],
    /// Name given to `new`; returned by `get_program_cstr` and used as the initial `argv[0]`.
    program: CString,
    /// Argument list including `argv[0]`; every mutation is mirrored into `argv`.
    args: Vec<CString>,
    /// Argument vector that is actually passed to `fexecve`.
    argv: Argv,
    /// Environment changes recorded via `env` / `envs` / `env_remove` / `env_clear`.
    env: CommandEnv,
    /// Directory to `chdir` into before exec, if one was set via `cwd`.
    cwd: Option<CString>,
    /// Stdin configuration; `None` means the default chosen by `spawn` / `exec`.
    pub stdin: Option<Stdio>,
    /// Stdout configuration; `None` means the default chosen by `spawn` / `exec`.
    pub stdout: Option<Stdio>,
    /// Stderr configuration; `None` means the default chosen by `spawn` / `exec`.
    pub stderr: Option<Stdio>,
    /// Set when any string converted via `os2c` / `construct_envp` contained an interior nul byte.
    saw_nul: bool,
}
/// Argument vector (`argv[0]` first) handed to `fexecve`.
///
/// `Vec<CString>` is already `Send + Sync`, so this newtype gets both traits
/// automatically. No `unsafe impl` is required.
#[derive(Debug)]
struct Argv(Vec<CString>);
/// Converts an `OsStr` into a `CString` suitable for passing to `exec`.
///
/// Strings that contain an interior nul byte cannot be represented. For those,
/// `saw_nul` is set to `true` and a fixed placeholder is returned instead, so the
/// caller can report the problem later rather than failing immediately.
/// `saw_nul` is never reset to `false` here.
fn os2c(s: &OsStr, saw_nul: &mut bool) -> CString {
    match CString::new(s.as_bytes()) {
        Ok(converted) => converted,
        Err(_) => {
            *saw_nul = true;
            CString::new("<string-with-nul>").unwrap()
        }
    }
}
/// Builds a `KEY=VALUE` C-string list from an environment map.
///
/// Entries appear in the map's (sorted) key order. Any entry whose combined
/// `KEY=VALUE` bytes contain an interior nul byte is skipped, and `saw_nul` is set to `true`.
fn construct_envp(env: BTreeMap<OsString, OsString>, saw_nul: &mut bool) -> Vec<CString> {
    let mut envp = Vec::with_capacity(env.len());
    envp.extend(env.into_iter().filter_map(|(key, value)| {
        // Reuse the key's buffer; +2 leaves room for '=' and the trailing nul.
        let mut entry = key;
        entry.reserve_exact(value.len() + 2);
        entry.push("=");
        entry.push(&value);
        match CString::new(entry.into_vec()) {
            Ok(c_entry) => Some(c_entry),
            Err(_) => {
                *saw_nul = true;
                None
            }
        }
    }));
    envp
}
impl<'a> MemFdExecutable<'a> {
pub fn new<S: AsRef<OsStr>>(name: S, code: &'a [u8]) -> Self {
let mut saw_nul = false;
let name = os2c(name.as_ref(), &mut saw_nul);
Self {
code,
program: name.clone(),
args: vec![name.clone()],
argv: Argv(vec![name]),
env: Default::default(),
cwd: None,
stdin: None,
stdout: None,
stderr: None,
saw_nul,
}
}
pub fn arg<S: AsRef<OsStr>>(&mut self, arg: S) -> &mut Self {
let arg = os2c(arg.as_ref(), &mut self.saw_nul);
self.argv.0.push(arg.clone());
self.args.push(arg);
self
}
pub fn args<I, S>(&mut self, args: I) -> &mut Self
where
I: IntoIterator<Item = S>,
S: AsRef<OsStr>,
{
for arg in args {
self.arg(arg.as_ref());
}
self
}
pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
where
K: AsRef<OsStr>,
V: AsRef<OsStr>,
{
self.env_mut().set(key.as_ref(), val.as_ref());
self
}
pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<OsStr>,
V: AsRef<OsStr>,
{
for (ref key, ref val) in vars {
self.env_mut().set(key.as_ref(), val.as_ref());
}
self
}
pub fn env_remove<K: AsRef<OsStr>>(&mut self, key: K) -> &mut Self {
self.env_mut().remove(key.as_ref());
self
}
pub fn env_clear(&mut self) -> &mut Self {
self.env_mut().clear();
self
}
pub fn cwd<P: AsRef<Path>>(&mut self, dir: P) -> &mut Self {
self.cwd = Some(os2c(dir.as_ref().as_ref(), &mut self.saw_nul));
self
}
pub fn stdin<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Self {
self.stdin = Some(cfg.into());
self
}
pub fn stdout<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Self {
self.stdout = Some(cfg.into());
self
}
pub fn stderr<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Self {
self.stderr = Some(cfg.into());
self
}
pub fn spawn(&mut self) -> Result<Child> {
let default = Stdio::Inherit;
let needs_stdin = true;
const CLOEXEC_MSG_FOOTER: [u8; 4] = *b"NOEX";
let envp = self.capture_env();
if self.saw_nul() {
}
let (ours, theirs) = self.setup_io(default, needs_stdin)?;
let (input, output) = anon_pipe()?;
let pid = unsafe { self.do_fork()? };
if pid == 0 {
drop(input);
let Err(err) = (unsafe { self.do_exec(theirs, envp) }) else { unreachable!("..."); };
panic!("failed to exec: {}", err);
}
drop(output);
let mut p = unsafe { Process::new(pid) };
let mut bytes = [0; 8];
loop {
match input.read(&mut bytes) {
Ok(0) => return Ok(Child::new(p, ours)),
Ok(8) => {
let (errno, footer) = bytes.split_at(4);
assert_eq!(
CLOEXEC_MSG_FOOTER, footer,
"Validation on the CLOEXEC pipe failed: {:?}",
bytes
);
let errno = i32::from_be_bytes(errno.try_into().unwrap());
assert!(p.wait().is_ok(), "wait() should either return Ok or panic");
return Err(Error::from_raw_os_error(errno));
}
Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
Err(e) => {
assert!(p.wait().is_ok(), "wait() should either return Ok or panic");
panic!("the CLOEXEC pipe failed: {e:?}")
}
Ok(..) => {
assert!(p.wait().is_ok(), "wait() should either return Ok or panic");
panic!("short read on the CLOEXEC pipe")
}
}
}
}
pub fn output(&mut self) -> Result<Output> {
self.spawn()?.wait_with_output()
}
pub fn status(&mut self) -> Result<ExitStatus> {
self.spawn()?.wait()
}
pub fn set_program(&mut self, program: &OsStr) {
let arg = os2c(program, &mut self.saw_nul);
self.argv.0[0] = arg.clone();
self.args[0] = arg;
}
fn env_mut(&mut self) -> &mut CommandEnv {
&mut self.env
}
fn setup_io(&self, default: Stdio, needs_stdin: bool) -> Result<(StdioPipes, ChildPipes)> {
let null = Stdio::Null;
let default_stdin = if needs_stdin { &default } else { &null };
let stdin = self.stdin.as_ref().unwrap_or(default_stdin);
let stdout = self.stdout.as_ref().unwrap_or(&default);
let stderr = self.stderr.as_ref().unwrap_or(&default);
let (their_stdin, our_stdin) = stdin.to_child_stdio(true)?;
let (their_stdout, our_stdout) = stdout.to_child_stdio(false)?;
let (their_stderr, our_stderr) = stderr.to_child_stdio(false)?;
let ours = StdioPipes {
stdin: our_stdin,
stdout: our_stdout,
stderr: our_stderr,
};
let theirs = ChildPipes {
stdin: their_stdin,
stdout: their_stdout,
stderr: their_stderr,
};
Ok((ours, theirs))
}
fn saw_nul(&self) -> bool {
self.saw_nul
}
pub fn get_cwd(&self) -> &Option<CString> {
&self.cwd
}
unsafe fn do_fork(&mut self) -> Result<pid_t> {
cvt(libc::fork())
}
fn capture_env(&mut self) -> Option<Vec<CString>> {
let maybe_env = self.env.capture_if_changed();
maybe_env.map(|env| construct_envp(env, &mut self.saw_nul))
}
pub fn exec(&mut self, default: Stdio) -> Error {
let envp = self.capture_env();
if self.saw_nul() {
return Error::new(ErrorKind::InvalidInput, "nul byte found in provided data");
}
match self.setup_io(default, true) {
Ok((_, theirs)) => unsafe {
let Err(e) = self.do_exec(theirs, envp) else { unreachable!("..."); };
e
},
Err(e) => e,
}
}
pub fn get_program_cstr(&self) -> &CStr {
&self.program
}
pub fn get_argv(&self) -> &Vec<CString> {
&self.argv.0
}
pub fn env_saw_path(&self) -> bool {
self.env.have_changed_path()
}
pub fn program_is_path(&self) -> bool {
self.program.to_bytes().contains(&b'/')
}
unsafe fn do_exec(
&mut self,
stdio: ChildPipes,
maybe_envp: Option<Vec<CString>>,
) -> Result<()> {
if let Some(fd) = stdio.stdin.fd() {
cvt_r(|| libc::dup2(fd, libc::STDIN_FILENO))?;
}
if let Some(fd) = stdio.stdout.fd() {
cvt_r(|| libc::dup2(fd, libc::STDOUT_FILENO))?;
}
if let Some(fd) = stdio.stderr.fd() {
cvt_r(|| libc::dup2(fd, libc::STDERR_FILENO))?;
}
if let Some(ref cwd) = *self.get_cwd() {
cvt(libc::chdir(cwd.as_ptr()))?;
}
{
let mut set = MaybeUninit::<libc::sigset_t>::uninit();
cvt(sigemptyset(set.as_mut_ptr()))?;
cvt_nz(libc::pthread_sigmask(
libc::SIG_SETMASK,
set.as_ptr(),
null_mut(),
))?;
{
let ret = signal(libc::SIGPIPE, libc::SIG_DFL);
if ret == libc::SIG_ERR {
return Err(Error::last_os_error());
}
}
}
let mfd = memfd_create(
CString::new("rust_exec").unwrap().as_c_str(),
MemFdCreateFlag::MFD_CLOEXEC,
)
.unwrap();
if let Ok(n) = write(mfd, self.code) {
if n != self.code.len() {
return Err(Error::new(
ErrorKind::BrokenPipe,
"Failed to write to memfd",
));
}
} else {
return Err(Error::last_os_error());
}
let argv = self
.get_argv()
.iter()
.map(|s| s.as_c_str())
.collect::<Vec<_>>();
let maybe_envp = maybe_envp.unwrap_or_default();
let envp = maybe_envp.iter().map(|s| s.as_c_str()).collect::<Vec<_>>();
if let Err(err) = fexecve(mfd, &argv, &envp) {
let _ = close(mfd);
return Err(Error::new(ErrorKind::BrokenPipe, err));
}
Err(Error::last_os_error())
}
}