#![allow(non_snake_case)]
#![deny(missing_docs)]
mod binding;
use std::{
env,
ffi::{c_void, OsStr, OsString},
fmt,
io::Error,
iter::once,
mem::size_of,
os::windows::ffi::OsStrExt,
path::{Path, PathBuf},
ptr::{null, null_mut},
};
use crate::binding::{
CloseHandle, CreateProcessW, GetExitCodeProcess, TerminateProcess, WaitForSingleObject, BOOL,
CREATE_UNICODE_ENVIRONMENT, DWORD, INFINITE, PCWSTR, PDWORD, PROCESS_INFORMATION, PWSTR,
SECURITY_ATTRIBUTES, STARTUPINFOW, STATUS_PENDING, UINT, WAIT_OBJECT_0,
};
/// Builder for spawning a Windows child process via `CreateProcessW`.
///
/// Configure it with the builder methods, then call `spawn` or `status`.
#[derive(Debug)]
pub struct Command {
    // Complete command line (program plus arguments) handed to the child.
    command: OsString,
    // Whether the child inherits this process's inheritable handles.
    inherit_handles: bool,
    // Working directory for the child; `None` passes no directory to `CreateProcessW`.
    current_directory: Option<PathBuf>,
    // When set, the child's environment starts empty instead of copying the parent's.
    env_clear: bool,
    // Ordered environment edits: `Some(value)` sets a variable, `None` removes it.
    env_vars: Vec<(OsString, Option<OsString>)>,
}
impl Command {
    /// Creates a builder for the given command line.
    ///
    /// `command` is the complete command line (program followed by its
    /// arguments, e.g. `"cmd.exe /c exit 0"`) and is handed to the child
    /// verbatim. Handles are not inherited, and no working directory or
    /// environment changes are configured.
    pub fn new(command: impl Into<OsString>) -> Self {
        Self {
            command: command.into(),
            inherit_handles: false,
            current_directory: None,
            env_clear: false,
            env_vars: Vec::new(),
        }
    }

    /// Sets whether the child inherits this process's inheritable handles.
    ///
    /// When `true`, the process and thread handles created for the child are
    /// inheritable as well. Defaults to `false`.
    pub fn inherit_handles(&mut self, inherit: bool) -> &mut Self {
        self.inherit_handles = inherit;
        self
    }

    /// Sets the working directory of the child process.
    pub fn current_dir(&mut self, dir: impl Into<PathBuf>) -> &mut Self {
        self.current_directory = Some(dir.into());
        self
    }

    /// Sets the environment variable `key` to `val` for the child.
    ///
    /// Keys are matched ignoring ASCII case. The last `env`/`env_remove` call
    /// for a given key wins.
    pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
    where
        K: AsRef<OsStr>,
        V: AsRef<OsStr>,
    {
        self.env_vars.push((
            key.as_ref().to_os_string(),
            Some(val.as_ref().to_os_string()),
        ));
        self
    }

    /// Sets several environment variables; equivalent to calling
    /// [`Command::env`] for each pair, in iteration order.
    pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<OsStr>,
        V: AsRef<OsStr>,
    {
        for (key, val) in vars {
            self.env(key, val);
        }
        self
    }

    /// Removes `key` (matched ignoring ASCII case) from the child's environment.
    pub fn env_remove<K>(&mut self, key: K) -> &mut Self
    where
        K: AsRef<OsStr>,
    {
        self.env_vars.push((key.as_ref().to_os_string(), None));
        self
    }

    /// Makes the child start from an empty environment.
    ///
    /// Any earlier `env`/`env_remove` calls are discarded. Variables added with
    /// `env` after this call are still passed.
    pub fn env_clear(&mut self) -> &mut Self {
        self.env_clear = true;
        self.env_vars.clear();
        self
    }

    /// Spawns the configured child process.
    ///
    /// The builder is left unchanged, so it can be spawned again with the same
    /// configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the child process could not be created.
    pub fn spawn(&mut self) -> Result<Child, Error> {
        Child::new(
            &self.command,
            self.inherit_handles,
            self.current_directory.as_deref(),
            self.env_clear,
            // Clone rather than `mem::take`: taking would silently drop every
            // environment override from the second spawn onwards.
            self.env_vars.clone(),
        )
    }

    /// Spawns the child, waits for it to exit, and returns its exit status.
    ///
    /// # Errors
    ///
    /// Returns an error if spawning or waiting fails.
    pub fn status(&mut self) -> Result<ExitStatus, Error> {
        self.spawn()?.wait()
    }
}
/// A child process started by [`Command::spawn`].
///
/// Holds the process and primary-thread handles returned by `CreateProcessW`.
#[derive(Debug)]
pub struct Child {
    // Handles and IDs filled in by `CreateProcessW`.
    process_information: PROCESS_INFORMATION,
}
impl Child {
fn new(
command: &OsStr,
inherit_handles: bool,
current_directory: Option<&Path>,
env_clear: bool,
env_vars: Vec<(OsString, Option<OsString>)>,
) -> Result<Self, Error> {
let mut startup_information = STARTUPINFOW::default();
let mut process_information = PROCESS_INFORMATION::default();
startup_information.cb = size_of::<STARTUPINFOW>() as u32;
let env_block = build_env_block(env_clear, env_vars);
let lp_env_ptr = env_block
.as_ref()
.map(|b| b.as_ptr() as *mut c_void)
.unwrap_or(null_mut());
let process_creation_flags = if lp_env_ptr.is_null() {
0
} else {
CREATE_UNICODE_ENVIRONMENT
};
let mut security_attributes;
let (lp_process_attributes, lp_thread_attributes) = if inherit_handles {
security_attributes = SECURITY_ATTRIBUTES::new(true);
(
&mut security_attributes as *mut SECURITY_ATTRIBUTES,
&mut security_attributes as *mut SECURITY_ATTRIBUTES,
)
} else {
(null_mut(), null_mut())
};
let current_directory_ptr = current_directory
.map(|path| {
let wide_path: Vec<u16> = path.as_os_str().encode_wide().chain(once(0)).collect();
wide_path.as_ptr()
})
.unwrap_or(std::ptr::null_mut());
let command = command.encode_wide().chain(once(0)).collect::<Vec<_>>();
let res = unsafe {
CreateProcessW(
null(),
command.as_ptr() as PWSTR,
lp_process_attributes,
lp_thread_attributes,
inherit_handles as BOOL,
process_creation_flags as DWORD,
lp_env_ptr,
current_directory_ptr as PCWSTR,
&startup_information,
&mut process_information,
)
};
if res != 0 {
Ok(Self {
process_information,
})
} else {
Err(Error::last_os_error())
}
}
pub fn kill(&self) -> Result<(), Error> {
let res = unsafe { TerminateProcess(self.process_information.hProcess, 0 as UINT) };
if res != 0 {
Ok(())
} else {
Err(Error::last_os_error())
}
}
pub fn wait(&self) -> Result<ExitStatus, Error> {
let mut exit_code = 0;
let wait = unsafe {
WaitForSingleObject(self.process_information.hProcess, INFINITE) == WAIT_OBJECT_0
};
if wait {
let res = unsafe {
GetExitCodeProcess(self.process_information.hProcess, &mut exit_code as PDWORD)
};
if res != 0 {
unsafe {
CloseHandle(self.process_information.hProcess);
CloseHandle(self.process_information.hThread);
}
Ok(ExitStatus(exit_code))
} else {
Err(Error::last_os_error())
}
} else {
Err(Error::last_os_error())
}
}
pub fn try_wait(&self) -> Result<Option<ExitStatus>, Error> {
let mut exit_code = 0;
let res = unsafe {
GetExitCodeProcess(self.process_information.hProcess, &mut exit_code as PDWORD)
};
if res != 0 {
if exit_code == STATUS_PENDING {
Ok(None)
} else {
unsafe {
CloseHandle(self.process_information.hProcess);
CloseHandle(self.process_information.hThread);
}
Ok(Some(ExitStatus(exit_code)))
}
} else {
Err(Error::last_os_error())
}
}
pub fn id(&self) -> u32 {
self.process_information.dwProcessId
}
}
/// Builds the UTF-16 environment block passed to `CreateProcessW`.
///
/// Returns `None` when the child should simply inherit the parent's
/// environment (no `env_clear` and no edits). Otherwise returns the complete
/// block:
/// * each entry is `KEY=VALUE\0`;
/// * entries are sorted by key with ASCII letters folded to uppercase, which
///   matches Windows' ordinal case-insensitive ordering;
/// * the block ends with one extra `\0`, or `\0\0` for an empty environment.
///
/// The base is empty when `env_clear` is set, otherwise the current process
/// environment. `env_vars` edits are then applied so that, among keys equal
/// ignoring ASCII case, the last edit wins. A `None` value removes the key.
///
/// NOTE(review): case folding covers ASCII only, so non-ASCII letters are
/// compared exactly. Keys or values containing NUL would corrupt the block and
/// are not rejected here.
fn build_env_block(
    env_clear: bool,
    env_vars: Vec<(OsString, Option<OsString>)>,
) -> Option<Vec<u16>> {
    /// UTF-16 code units of `s` with ASCII `a`..=`z` mapped to `A`..=`Z`.
    fn ascii_upper_wide(s: &OsStr) -> impl Iterator<Item = u16> + '_ {
        s.encode_wide().map(|c| {
            if (u16::from(b'a')..=u16::from(b'z')).contains(&c) {
                c - 32
            } else {
                c
            }
        })
    }

    /// Compares two keys the way Windows does for ASCII letters (case-insensitively).
    fn eq_ignore_ascii_case(a: &OsStr, b: &OsStr) -> bool {
        ascii_upper_wide(a).eq(ascii_upper_wide(b))
    }

    if !env_clear && env_vars.is_empty() {
        return None;
    }

    let mut map: Vec<(OsString, OsString)> = if env_clear {
        Vec::new()
    } else {
        env::vars_os().collect()
    };

    // Apply edits newest-first. The newest edit for a key decides its fate, and
    // older edits for the same key are skipped.
    let mut seen: Vec<OsString> = Vec::new();
    for (key, val) in env_vars.into_iter().rev() {
        if seen.iter().any(|k| eq_ignore_ascii_case(k, &key)) {
            continue;
        }
        map.retain(|(k, _)| !eq_ignore_ascii_case(k, &key));
        seen.push(key.clone());
        if let Some(val) = val {
            map.push((key, val));
        }
    }

    // Windows orders environment names case-insensitively by folding to
    // *uppercase*. Lowercase folding mis-orders keys containing `[ \ ] ^ _ \``.
    map.sort_by_cached_key(|(key, _)| ascii_upper_wide(key).collect::<Vec<u16>>());

    let mut block: Vec<u16> = Vec::new();
    for (key, val) in &map {
        block.extend(key.encode_wide());
        block.push(u16::from(b'='));
        block.extend(val.encode_wide());
        block.push(0);
    }
    // An empty environment still needs two terminating NULs.
    if map.is_empty() {
        block.push(0);
    }
    block.push(0);
    Some(block)
}
/// Exit status of a finished child process: the raw Win32 exit code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExitStatus(u32);

impl ExitStatus {
    /// Returns `true` if the process exited with code 0.
    pub fn success(&self) -> bool {
        self.0 == 0
    }

    /// Returns the raw exit code.
    pub fn code(&self) -> u32 {
        self.0
    }
}

impl fmt::Display for ExitStatus {
    /// Formats the status as its decimal exit code, honoring width/fill flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Call `Display::fmt` explicitly: only `fmt` is imported, so method
        // syntax would depend on which formatting traits happen to be in scope.
        fmt::Display::fmt(&self.0, f)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::binding::{GetHandleInformation, HANDLE_FLAG_INHERIT};

    /// Reports whether `handle` carries `HANDLE_FLAG_INHERIT`.
    ///
    /// # Safety
    ///
    /// `handle` must be a valid, open handle.
    unsafe fn is_inheritable(handle: *mut c_void) -> bool {
        let mut flags: DWORD = 0;
        let queried = GetHandleInformation(handle, &mut flags) != 0;
        assert!(queried, "GetHandleInformation failed");
        (flags & HANDLE_FLAG_INHERIT) != 0
    }

    /// Spawns `cmd`, records whether its process and thread handles are
    /// inheritable (in that order), then waits for it to finish.
    fn spawned_handle_inheritability(cmd: &mut Command) -> (bool, bool) {
        let child = cmd.spawn().unwrap();
        let info = &child.process_information;
        // SAFETY: both handles belong to `child`, which has not been waited on yet.
        let flags = unsafe { (is_inheritable(info.hProcess), is_inheritable(info.hThread)) };
        child.wait().unwrap();
        flags
    }

    /// Spawns `cmd`, waits for it, and returns its exit code.
    fn exit_code_of(cmd: &mut Command) -> u32 {
        cmd.spawn().unwrap().wait().unwrap().code()
    }

    #[test]
    fn default_spawn_does_not_give_inheritable_handles() {
        let (proc_inheritable, thread_inheritable) =
            spawned_handle_inheritability(&mut Command::new("cmd.exe /c exit 0"));
        assert!(
            !proc_inheritable,
            "process handle should NOT be inheritable by default"
        );
        assert!(
            !thread_inheritable,
            "thread handle should NOT be inheritable by default"
        );
    }

    #[test]
    fn inherit_handles_true_gives_inheritable_handles() {
        let (proc_inheritable, thread_inheritable) = spawned_handle_inheritability(
            Command::new("cmd.exe /c exit 0").inherit_handles(true),
        );
        assert!(
            proc_inheritable,
            "process handle should be inheritable when inherit_handles(true)"
        );
        assert!(
            thread_inheritable,
            "thread handle should be inheritable when inherit_handles(true)"
        );
    }

    #[test]
    fn env_var_is_passed_to_child() {
        let code = exit_code_of(
            Command::new(r#"cmd.exe /c "if "%MY_VAR%"=="hello_test" (exit 0) else (exit 1)""#)
                .env("MY_VAR", "hello_test"),
        );
        assert_eq!(code, 0, "MY_VAR should be 'hello_test'");
    }

    #[test]
    fn env_clear_with_single_var() {
        let code = exit_code_of(
            Command::new(
                r#"cmd.exe /c "if defined PATH (exit 1) else (if "%CUSTOM%"=="value" (exit 0) else (exit 2))""#,
            )
            .env_clear()
            .env("CUSTOM", "value"),
        );
        assert_eq!(
            code,
            0,
            "PATH should be unset and CUSTOM should be 'value'"
        );
    }

    #[test]
    fn env_remove_removes_var() {
        let code = exit_code_of(
            Command::new("cmd.exe /c \"if defined PATH (exit 1) else (exit 0)\"")
                .env_remove("PATH"),
        );
        assert_eq!(code, 0);
    }

    #[test]
    fn no_env_args_inherits_parent() {
        let code = exit_code_of(&mut Command::new(
            "cmd.exe /c \"if defined PATH (exit 0) else (exit 1)\"",
        ));
        assert_eq!(code, 0);
    }

    #[test]
    fn last_duplicate_key_wins() {
        let code = exit_code_of(
            Command::new(
                r#"cmd.exe /c "if "%MY_VAR%"=="second" (exit 2) else (if "%MY_VAR%"=="first" (exit 0) else (exit 3))""#,
            )
            .env("MY_VAR", "first")
            .env("MY_VAR", "second"),
        );
        assert_eq!(
            code,
            2,
            "duplicate key should keep the last value 'second'"
        );
    }

    #[test]
    fn last_duplicate_key_wins_case_insensitive() {
        let code = exit_code_of(
            Command::new(
                r#"cmd.exe /c "if "%MYVAR%"=="second" (exit 2) else (if "%MYVAR%"=="first" (exit 0) else (exit 3))""#,
            )
            .env("MyVar", "first")
            .env("MYVAR", "second"),
        );
        assert_eq!(
            code,
            2,
            "case-insensitive duplicate key should keep the last value 'second'"
        );
    }

    #[test]
    fn env_overrides_earlier_remove() {
        let code = exit_code_of(
            Command::new(r#"cmd.exe /c "if "%MY_VAR%"=="hello_override" (exit 0) else (exit 1)""#)
                .env_remove("MY_VAR")
                .env("MY_VAR", "hello_override"),
        );
        assert_eq!(code, 0, "env should override earlier env_remove");
    }

    #[test]
    fn remove_overrides_earlier_env() {
        let code = exit_code_of(
            Command::new(r#"cmd.exe /c "if defined MY_VAR (exit 1) else (exit 0)""#)
                .env("MY_VAR", "some_value")
                .env_remove("MY_VAR"),
        );
        assert_eq!(code, 0, "env_remove should override earlier env");
    }
}