use std::ffi::OsStr;
use std::future::Future;
use std::os::windows::ffi::OsStrExt;
use std::os::windows::io::{AsRawHandle, FromRawHandle, OwnedHandle, RawHandle};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::{io, ptr};
use windows_sys::Win32::Foundation::{CloseHandle, HANDLE};
use windows_sys::Win32::System::Console::HPCON;
use windows_sys::Win32::System::JobObjects::{
AssignProcessToJobObject, CreateJobObjectW, JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
JOBOBJECT_EXTENDED_LIMIT_INFORMATION, JobObjectExtendedLimitInformation,
SetInformationJobObject, TerminateJobObject,
};
use windows_sys::Win32::System::Threading::{
CREATE_UNICODE_ENVIRONMENT, CreateProcessW, EXTENDED_STARTUPINFO_PRESENT, GetExitCodeProcess,
INFINITE, InitializeProcThreadAttributeList, PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE,
PROCESS_INFORMATION, STARTUPINFOEXW, UpdateProcThreadAttribute, WaitForSingleObject,
};
/// Win32 `BOOL`: a 32-bit integer used as a success/failure flag by the API calls below.
type BOOL = i32;
/// Win32 `FALSE`; `BOOL`-returning calls in this file treat this value as failure.
const FALSE: BOOL = 0x0;
/// Win32 `TRUE`. Not referenced elsewhere in this file at present, hence the
/// `dead_code` allowance.
#[allow(dead_code)]
const TRUE: BOOL = 0x1;
/// `WaitForSingleObject` result meaning the waited-on object is signaled.
const WAIT_OBJECT_0: u32 = 0x0000_0000;
use crate::config::{PtyConfig, PtySignal};
use crate::error::{PtyError, Result};
use crate::traits::{ExitStatus, PtyChild};
/// A child process spawned attached to a Windows pseudo console.
#[derive(Debug)]
pub struct WindowsPtyChild {
    /// Owned handle to the child process; waited on and queried for its exit code.
    process: OwnedHandle,
    /// Process id; also passed to `GenerateConsoleCtrlEvent` as the process-group id.
    pid: u32,
    /// Job object associated with the child, if any. When present, termination
    /// uses `TerminateJobObject` instead of `TerminateProcess`.
    job: Option<OwnedHandle>,
    /// Cached "still running" flag, cleared once an exit is observed or after `kill`.
    /// Wrapped in `Arc`, presumably so it can be shared elsewhere — not visible here.
    running: Arc<AtomicBool>,
    /// Exit status cached after the first successful `wait`/`try_wait`.
    exit_status: Option<ExitStatus>,
}
impl WindowsPtyChild {
    /// Wraps an already-spawned child process.
    ///
    /// `process` is the handle that is waited on and queried for the exit code,
    /// `pid` its process id, and `job` an optional job object used for
    /// termination. The child starts out marked as running, with no recorded
    /// exit status.
    pub fn new(process: OwnedHandle, pid: u32, job: Option<OwnedHandle>) -> Self {
        Self {
            process,
            pid,
            job,
            running: Arc::new(AtomicBool::new(true)),
            exit_status: None,
        }
    }

    /// Returns the child's process id.
    #[must_use]
    pub fn pid(&self) -> u32 {
        self.pid
    }

    /// Returns whether the child is believed to still be running.
    ///
    /// This reads a cached flag rather than querying the OS. The flag only
    /// becomes `false` after [`wait`](Self::wait) or [`try_wait`](Self::try_wait)
    /// has observed the exit, or after [`kill`](Self::kill) has succeeded.
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::SeqCst)
    }

    /// Waits (without blocking the async runtime) for the child to exit.
    ///
    /// The first observed status is cached, and later calls return it immediately.
    ///
    /// # Errors
    ///
    /// Returns [`PtyError::Wait`] in three cases: the handle cannot be
    /// duplicated, the wait or exit-code query fails, or the blocking task
    /// panics or is cancelled.
    pub async fn wait(&mut self) -> Result<ExitStatus> {
        if let Some(status) = self.exit_status {
            return Ok(status);
        }
        // Give the blocking task its own duplicate of the process handle. If this
        // future is dropped early, the task keeps running on the blocking pool.
        // Passing only the raw value of `self.process` would let that task
        // outlive `self` and wait on a closed, possibly reused, handle.
        let process = self.process.try_clone().map_err(PtyError::Wait)?;
        let exit_code = tokio::task::spawn_blocking(move || {
            let handle = process.as_raw_handle() as HANDLE;
            // SAFETY: `handle` is owned by `process`, which lives until the end
            // of this closure.
            if unsafe { WaitForSingleObject(handle, INFINITE) } != WAIT_OBJECT_0 {
                return Err(io::Error::last_os_error());
            }
            Self::query_exit_code(handle)
        })
        .await
        .map_err(|e| PtyError::Wait(io::Error::other(e)))?
        .map_err(PtyError::Wait)?;
        Ok(self.record_exit(exit_code))
    }

    /// Checks, without blocking, whether the child has exited.
    ///
    /// Returns `Ok(None)` while the child is still running. Otherwise it returns
    /// the exit status, which is cached.
    ///
    /// # Errors
    ///
    /// Returns [`PtyError::Wait`] if the wait itself fails (e.g. `WAIT_FAILED`)
    /// or if the exit code cannot be read.
    pub fn try_wait(&mut self) -> Result<Option<ExitStatus>> {
        // `WaitForSingleObject` result when the timeout elapsed without a signal.
        const WAIT_TIMEOUT: u32 = 0x0000_0102;

        if let Some(status) = self.exit_status {
            return Ok(Some(status));
        }
        let handle = self.process.as_raw_handle() as HANDLE;
        // SAFETY: `handle` is owned by `self.process` and valid for this call.
        match unsafe { WaitForSingleObject(handle, 0) } {
            WAIT_OBJECT_0 => {
                let exit_code = Self::query_exit_code(handle).map_err(PtyError::Wait)?;
                Ok(Some(self.record_exit(exit_code)))
            }
            WAIT_TIMEOUT => Ok(None),
            // WAIT_FAILED (or anything unexpected): report it instead of
            // pretending the child is still running.
            _ => Err(PtyError::Wait(io::Error::last_os_error())),
        }
    }

    /// Delivers `signal` to the child.
    ///
    /// - `Interrupt` / `Quit` call `GenerateConsoleCtrlEvent` with `CTRL_C_EVENT` /
    ///   `CTRL_BREAK_EVENT`, passing the child's pid as the process-group id.
    /// - `Terminate`, `Kill` and `Hangup` forcibly terminate the child (exit code 1).
    /// - `WindowChange` is a no-op.
    ///
    /// NOTE(review): the `GenerateConsoleCtrlEvent` docs say `CTRL_C_EVENT` is not
    /// received by a nonzero process group. The call also targets the caller's
    /// console rather than the pseudo console. Confirm that these events
    /// actually reach a ConPTY child; writing `0x03` to the PTY input may be the
    /// reliable way to interrupt it.
    ///
    /// # Errors
    ///
    /// Returns [`PtyError::Signal`] if the underlying Win32 call fails.
    pub fn signal(&self, signal: PtySignal) -> Result<()> {
        use windows_sys::Win32::System::Console::{
            CTRL_BREAK_EVENT, CTRL_C_EVENT, GenerateConsoleCtrlEvent,
        };
        let event = match signal {
            PtySignal::Interrupt => CTRL_C_EVENT,
            PtySignal::Quit => CTRL_BREAK_EVENT,
            PtySignal::Terminate | PtySignal::Kill | PtySignal::Hangup => {
                return self.terminate_impl();
            }
            PtySignal::WindowChange => return Ok(()),
        };
        // SAFETY: plain FFI call with no pointer arguments.
        if unsafe { GenerateConsoleCtrlEvent(event, self.pid) } == FALSE {
            return Err(PtyError::Signal(io::Error::last_os_error()));
        }
        Ok(())
    }

    /// Forcibly terminates the child with exit code 1.
    ///
    /// With a job object, every process in the job is terminated. Without one,
    /// only the child process is terminated.
    fn terminate_impl(&self) -> Result<()> {
        use windows_sys::Win32::System::Threading::TerminateProcess;
        let ok = match &self.job {
            // SAFETY: the job handle is owned by `self` and valid for the call.
            Some(job) => unsafe { TerminateJobObject(job.as_raw_handle() as HANDLE, 1) },
            // SAFETY: the process handle is owned by `self` and valid for the call.
            None => unsafe { TerminateProcess(self.process.as_raw_handle() as HANDLE, 1) },
        };
        if ok == FALSE {
            return Err(PtyError::Signal(io::Error::last_os_error()));
        }
        Ok(())
    }

    /// Forcibly terminates the child and marks it as no longer running.
    ///
    /// The exit status is not recorded here; call [`wait`](Self::wait) to collect it.
    ///
    /// # Errors
    ///
    /// Returns [`PtyError::Signal`] if termination fails.
    pub fn kill(&mut self) -> Result<()> {
        self.terminate_impl()?;
        self.running.store(false, Ordering::SeqCst);
        Ok(())
    }

    /// Reads the exit code of the process behind `handle`.
    ///
    /// Callers only use this after a wait has reported the process as signaled.
    fn query_exit_code(handle: HANDLE) -> io::Result<u32> {
        let mut exit_code: u32 = 0;
        // SAFETY: `handle` is a valid process handle for the duration of the
        // call, and `exit_code` is a valid out-pointer.
        if unsafe { GetExitCodeProcess(handle, &mut exit_code) } == FALSE {
            return Err(io::Error::last_os_error());
        }
        Ok(exit_code)
    }

    /// Caches the exit status for `exit_code`, clears the running flag, and
    /// returns the status.
    fn record_exit(&mut self, exit_code: u32) -> ExitStatus {
        let status = ExitStatus::Terminated(exit_code);
        self.exit_status = Some(status);
        self.running.store(false, Ordering::SeqCst);
        status
    }
}
/// [`PtyChild`] for Windows: every trait method forwards to the inherent
/// `WindowsPtyChild` method of the same name.
impl PtyChild for WindowsPtyChild {
    fn pid(&self) -> u32 {
        Self::pid(self)
    }

    fn is_running(&self) -> bool {
        Self::is_running(self)
    }

    fn wait(&mut self) -> Pin<Box<dyn Future<Output = Result<ExitStatus>> + Send + '_>> {
        let fut = Self::wait(self);
        Box::pin(fut)
    }

    fn try_wait(&mut self) -> Result<Option<ExitStatus>> {
        Self::try_wait(self)
    }

    fn signal(&self, signal: PtySignal) -> Result<()> {
        Self::signal(self, signal)
    }

    fn kill(&mut self) -> Result<()> {
        Self::kill(self)
    }
}
pub fn spawn_child<S, I>(
hpc: HPCON,
program: S,
args: I,
config: &PtyConfig,
) -> Result<WindowsPtyChild>
where
S: AsRef<OsStr>,
I: IntoIterator,
I::Item: AsRef<OsStr>,
{
let mut cmdline = escape_argument(program.as_ref());
for arg in args {
cmdline.push(b' ' as u16);
cmdline.extend(escape_argument(arg.as_ref()));
}
cmdline.push(0);
let env_block = build_environment_block(&config.effective_env());
let working_dir = config.working_directory.as_ref().map(|p| {
let mut w = to_wide_string(p.as_os_str());
w.push(0);
w
});
let job = create_job_object()?;
let (startup_info, _attr_list) = create_startup_info(hpc)?;
let mut process_info: PROCESS_INFORMATION = unsafe { std::mem::zeroed() };
let result = unsafe {
CreateProcessW(
ptr::null(),
cmdline.as_mut_ptr(),
ptr::null(),
ptr::null(),
FALSE,
EXTENDED_STARTUPINFO_PRESENT | CREATE_UNICODE_ENVIRONMENT,
if env_block.is_empty() {
ptr::null()
} else {
env_block.as_ptr() as *const _
},
working_dir.as_ref().map_or(ptr::null(), |w| w.as_ptr()),
&startup_info.StartupInfo,
&mut process_info,
)
};
if result == FALSE {
return Err(PtyError::Spawn(io::Error::last_os_error()));
}
unsafe {
CloseHandle(process_info.hThread);
}
let process = unsafe { OwnedHandle::from_raw_handle(process_info.hProcess as RawHandle) };
if let Some(ref job_handle) = job {
unsafe {
AssignProcessToJobObject(
job_handle.as_raw_handle() as HANDLE,
process.as_raw_handle() as HANDLE,
);
}
}
Ok(WindowsPtyChild::new(process, process_info.dwProcessId, job))
}
/// Returns the UTF-16 code units of `s`, exactly as stored, with no trailing NUL.
fn to_wide_string(s: &OsStr) -> Vec<u16> {
    let mut units = Vec::new();
    units.extend(s.encode_wide());
    units
}
fn escape_argument(arg: &OsStr) -> Vec<u16> {
let arg_str = arg.to_string_lossy();
let needs_quoting = arg_str.is_empty()
|| arg_str.contains(' ')
|| arg_str.contains('\t')
|| arg_str.contains('"')
|| arg_str.contains('\\');
if !needs_quoting {
return to_wide_string(arg);
}
let mut result = Vec::new();
result.push(b'"' as u16);
let chars: Vec<char> = arg_str.chars().collect();
let mut i = 0;
while i < chars.len() {
let c = chars[i];
if c == '\\' {
let mut num_backslashes = 0;
while i < chars.len() && chars[i] == '\\' {
num_backslashes += 1;
i += 1;
}
if i < chars.len() && chars[i] == '"' {
for _ in 0..(num_backslashes * 2) {
result.push(b'\\' as u16);
}
result.push(b'\\' as u16);
result.push(b'"' as u16);
i += 1;
} else if i >= chars.len() {
for _ in 0..(num_backslashes * 2) {
result.push(b'\\' as u16);
}
} else {
for _ in 0..num_backslashes {
result.push(b'\\' as u16);
}
}
} else if c == '"' {
result.push(b'\\' as u16);
result.push(b'"' as u16);
i += 1;
} else {
for code_unit in c.encode_utf16(&mut [0u16; 2]) {
result.push(*code_unit);
}
i += 1;
}
}
result.push(b'"' as u16); result
}
fn build_environment_block(
env: &std::collections::HashMap<std::ffi::OsString, std::ffi::OsString>,
) -> Vec<u16> {
let mut block = Vec::new();
for (key, value) in env {
block.extend(to_wide_string(key));
block.push(b'=' as u16);
block.extend(to_wide_string(value));
block.push(0);
}
block.push(0); block
}
fn create_job_object() -> Result<Option<OwnedHandle>> {
let job = unsafe { CreateJobObjectW(ptr::null(), ptr::null()) };
if job.is_null() {
return Ok(None);
}
let mut info: JOBOBJECT_EXTENDED_LIMIT_INFORMATION = unsafe { std::mem::zeroed() };
info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
let result = unsafe {
SetInformationJobObject(
job,
JobObjectExtendedLimitInformation,
&info as *const _ as *const _,
std::mem::size_of::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() as u32,
)
};
if result == FALSE {
unsafe {
CloseHandle(job);
}
return Ok(None);
}
Ok(Some(unsafe {
OwnedHandle::from_raw_handle(job as RawHandle)
}))
}
fn create_startup_info(hpc: HPCON) -> Result<(STARTUPINFOEXW, Vec<u8>)> {
let mut size: usize = 0;
unsafe {
InitializeProcThreadAttributeList(ptr::null_mut(), 1, 0, &mut size);
}
let mut attr_list = vec![0u8; size];
let result = unsafe {
InitializeProcThreadAttributeList(attr_list.as_mut_ptr() as *mut _, 1, 0, &mut size)
};
if result == FALSE {
return Err(PtyError::Spawn(io::Error::last_os_error()));
}
let result = unsafe {
UpdateProcThreadAttribute(
attr_list.as_mut_ptr() as *mut _,
0,
PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE as usize,
hpc as *mut _,
std::mem::size_of::<HPCON>(),
ptr::null_mut(),
ptr::null_mut(),
)
};
if result == FALSE {
return Err(PtyError::Spawn(io::Error::last_os_error()));
}
let mut startup_info: STARTUPINFOEXW = unsafe { std::mem::zeroed() };
startup_info.StartupInfo.cb = std::mem::size_of::<STARTUPINFOEXW>() as u32;
startup_info.lpAttributeList = attr_list.as_mut_ptr() as *mut _;
Ok((startup_info, attr_list))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Escapes `arg` and decodes the UTF-16 result back into a `String` for comparison.
    fn escaped(arg: &str) -> String {
        String::from_utf16_lossy(&escape_argument(OsStr::new(arg)))
    }

    #[test]
    fn wide_string_conversion() {
        let expected: Vec<u16> = "hello".bytes().map(u16::from).collect();
        assert_eq!(to_wide_string(OsStr::new("hello")), expected);
    }

    #[test]
    fn escape_simple_argument() {
        assert_eq!(escaped("hello"), "hello");
    }

    #[test]
    fn escape_argument_with_space() {
        assert_eq!(escaped("hello world"), r#""hello world""#);
    }

    #[test]
    fn escape_argument_with_tab() {
        assert_eq!(escaped("hello\tworld"), "\"hello\tworld\"");
    }

    #[test]
    fn escape_empty_argument() {
        assert_eq!(escaped(""), r#""""#);
    }

    #[test]
    fn escape_argument_with_quote() {
        assert_eq!(escaped(r#"say "hello""#), r#""say \"hello\"""#);
    }

    #[test]
    fn escape_argument_with_backslash() {
        assert_eq!(escaped(r"C:\Users\test"), r#""C:\Users\test""#);
    }

    #[test]
    fn escape_argument_with_trailing_backslash() {
        assert_eq!(escaped(r"C:\Users\"), r#""C:\Users\\""#);
    }

    #[test]
    fn escape_argument_with_multiple_trailing_backslashes() {
        assert_eq!(escaped(r"path\\"), r#""path\\\\""#);
    }

    #[test]
    fn escape_argument_backslash_before_quote() {
        assert_eq!(escaped(r#"test\"value"#), r#""test\\\"value""#);
    }

    #[test]
    fn escape_argument_multiple_backslashes_before_quote() {
        assert_eq!(escaped(r#"test\\"value"#), r#""test\\\\\"value""#);
    }

    #[test]
    fn escape_complex_path() {
        assert_eq!(
            escaped(r"C:\Program Files\My App\bin"),
            r#""C:\Program Files\My App\bin""#
        );
    }

    #[test]
    fn escape_unc_path() {
        assert_eq!(escaped(r"\\server\share\folder"), r#""\\server\share\folder""#);
    }
}