#![allow(unsafe_op_in_unsafe_fn)]
use anyhow::{Result, anyhow};
use std::collections::HashMap;
use std::path::Path;
use std::ptr;
use windows_sys::Win32::Foundation::{
CloseHandle, GetLastError, HANDLE, HANDLE_FLAG_INHERIT, SetHandleInformation,
};
use windows_sys::Win32::System::Pipes::CreatePipe;
use windows_sys::Win32::System::Threading::{
CreateProcessAsUserW, GetExitCodeProcess, INFINITE, PROCESS_INFORMATION,
STARTF_USESTDHANDLES, STARTUPINFOW, TerminateProcess,
WaitForSingleObject, CREATE_UNICODE_ENVIRONMENT,
};
use windows_sys::Win32::Storage::FileSystem::ReadFile;
use crate::winutil::to_wide;
/// Outcome of running a child process via [`execute_with_token`] or
/// [`execute_with_token_stream`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProcessResult {
    /// Exit code reported by `GetExitCodeProcess`; left at `1` when the process timed out.
    pub exit_code: u32,
    /// Bytes read from the child's stdout pipe.
    pub stdout: Vec<u8>,
    /// Bytes read from the child's stderr pipe.
    pub stderr: Vec<u8>,
    /// `true` if the timeout elapsed and the process was terminated.
    pub timed_out: bool,
}
/// Builds a Unicode environment block for `CreateProcessAsUserW` with
/// `CREATE_UNICODE_ENVIRONMENT`.
///
/// Each entry is encoded as UTF-16 `KEY=VALUE\0`. Entries are sorted case-insensitively by key,
/// with ties broken by the exact key. The block ends with one extra `\0`, so an empty map
/// yields `[0, 0]`.
///
/// Entries that cannot be represented are skipped with a warning on stderr, rather than
/// corrupting the block:
/// - an empty key;
/// - a NUL in the key or value, which would end the entry (or the whole block) early;
/// - an `=` in the key after its first character, which would move the name/value split.
///   A leading `=`, as in `=C:`, is allowed.
fn make_env_block(env: &HashMap<String, String>) -> Vec<u16> {
    let mut items: Vec<(&String, &String)> = env
        .iter()
        .filter(|&(k, v)| {
            let valid = !k.is_empty()
                && !k.contains('\0')
                && !v.contains('\0')
                && !k.chars().skip(1).any(|c| c == '=');
            if !valid {
                eprintln!("⚠ 跳过无法编码的环境变量: {k:?}");
            }
            valid
        })
        .collect();
    // Windows expects entries sorted by name, case-insensitively. Compute each uppercase key
    // once instead of on every comparison; the exact key breaks ties deterministically.
    items.sort_by_cached_key(|&(k, _)| (k.to_uppercase(), k.clone()));

    let mut block: Vec<u16> = Vec::new();
    for (k, v) in &items {
        block.extend(k.encode_utf16());
        block.push('=' as u16);
        block.extend(v.encode_utf16());
        block.push(0);
    }
    if items.is_empty() {
        // An empty block still needs an entry terminator in addition to the final one.
        block.push(0);
    }
    block.push(0);
    block
}
/// Joins `args` into one Windows command line using the MSVC CRT / `CommandLineToArgvW`
/// quoting rules.
///
/// An argument is wrapped in double quotes when it contains a space, a tab or a `"`, or when it
/// ends with a backslash. Inside quotes:
/// - a run of backslashes followed by `"` is doubled, and the quote is escaped;
/// - a run of backslashes at the end of the argument is doubled, so the closing quote survives;
/// - any other backslashes are copied as-is.
///
/// The argument is processed per `char`, so non-ASCII text (e.g. CJK paths) is preserved.
///
/// NOTE(review): an empty argument is emitted as nothing, so the child will not see it. This is
/// pinned by `test_argv_empty_arg`. Emit `""` instead if empty arguments must survive.
fn argv_to_command_line(args: &[String]) -> String {
    args.iter()
        .map(|arg| {
            let needs_quotes =
                arg.contains(|c: char| matches!(c, ' ' | '\t' | '"')) || arg.ends_with('\\');
            if !needs_quotes {
                return arg.clone();
            }
            let mut quoted = String::with_capacity(arg.len() + 2);
            quoted.push('"');
            let mut chars = arg.chars().peekable();
            loop {
                // Count the run of backslashes preceding the next significant character.
                let mut backslashes = 0usize;
                while chars.next_if_eq(&'\\').is_some() {
                    backslashes += 1;
                }
                match chars.next() {
                    None => {
                        // Double trailing backslashes so they don't escape the closing quote.
                        quoted.extend(std::iter::repeat('\\').take(backslashes * 2));
                        break;
                    }
                    Some('"') => {
                        quoted.extend(std::iter::repeat('\\').take(backslashes * 2 + 1));
                        quoted.push('"');
                    }
                    Some(c) => {
                        quoted.extend(std::iter::repeat('\\').take(backslashes));
                        quoted.push(c);
                    }
                }
            }
            quoted.push('"');
            quoted
        })
        .collect::<Vec<_>>()
        .join(" ")
}
unsafe fn create_pipes() -> Result<((HANDLE, HANDLE), (HANDLE, HANDLE), (HANDLE, HANDLE))> {
let mut in_r: HANDLE = 0;
let mut in_w: HANDLE = 0;
let mut out_r: HANDLE = 0;
let mut out_w: HANDLE = 0;
let mut err_r: HANDLE = 0;
let mut err_w: HANDLE = 0;
if CreatePipe(&mut in_r, &mut in_w, ptr::null_mut(), 0) == 0 {
return Err(anyhow!("CreatePipe stdin failed: {}", GetLastError()));
}
if CreatePipe(&mut out_r, &mut out_w, ptr::null_mut(), 0) == 0 {
CloseHandle(in_r); CloseHandle(in_w);
return Err(anyhow!("CreatePipe stdout failed: {}", GetLastError()));
}
if CreatePipe(&mut err_r, &mut err_w, ptr::null_mut(), 0) == 0 {
CloseHandle(in_r); CloseHandle(in_w);
CloseHandle(out_r); CloseHandle(out_w);
return Err(anyhow!("CreatePipe stderr failed: {}", GetLastError()));
}
Ok(((in_r, in_w), (out_r, out_w), (err_r, err_w)))
}
pub unsafe fn execute_with_token(
h_token: HANDLE,
command: &[String],
cwd: &Path,
env_map: &HashMap<String, String>,
timeout_ms: Option<u64>,
desktop_name: Option<&str>,
process_handle_out: Option<&std::sync::Arc<std::sync::Mutex<Option<HANDLE>>>>,
) -> Result<ProcessResult> {
if command.is_empty() {
return Err(anyhow!("Empty command"));
}
let ((in_r, in_w), (out_r, out_w), (err_r, err_w)) = create_pipes()?;
let cmdline_str = argv_to_command_line(command);
let mut cmdline = to_wide(&cmdline_str);
let env_block = if env_map.is_empty() {
let current_env: HashMap<String, String> = std::env::vars().collect();
make_env_block(¤t_env)
} else {
make_env_block(env_map)
};
let cwd_wide = to_wide(cwd);
let desktop_str = desktop_name.unwrap_or("Winsta0\\Default");
let desktop_name_wide = to_wide(desktop_str);
let mut si: STARTUPINFOW = std::mem::zeroed();
si.cb = std::mem::size_of::<STARTUPINFOW>() as u32;
si.lpDesktop = desktop_name_wide.as_ptr() as *mut u16; si.dwFlags = STARTF_USESTDHANDLES;
si.hStdInput = in_r;
si.hStdOutput = out_w;
si.hStdError = err_w;
for &handle in &[in_r, out_w, err_w] {
if SetHandleInformation(handle, HANDLE_FLAG_INHERIT, HANDLE_FLAG_INHERIT) == 0 {
CloseHandle(in_r); CloseHandle(in_w);
CloseHandle(out_r); CloseHandle(out_w);
CloseHandle(err_r); CloseHandle(err_w);
return Err(anyhow!("SetHandleInformation failed: {}", GetLastError()));
}
}
let mut pi: PROCESS_INFORMATION = std::mem::zeroed();
let ok = CreateProcessAsUserW(
h_token,
ptr::null(), cmdline.as_mut_ptr(),
ptr::null_mut(), ptr::null_mut(), 1, CREATE_UNICODE_ENVIRONMENT,
env_block.as_ptr() as *mut _,
cwd_wide.as_ptr(),
&si,
&mut pi,
);
CloseHandle(in_r);
CloseHandle(in_w);
CloseHandle(out_w);
CloseHandle(err_w);
if ok == 0 {
let err = GetLastError();
CloseHandle(out_r);
CloseHandle(err_r);
let hint = match err {
2 => {
let first_cmd = command.first().map(|s| s.as_str()).unwrap_or("");
format!(
"\n 💡 提示: '{first_cmd}' 不是可执行文件(cmd 内部命令需要用 cmd /c 前缀)\n \
正确用法: wsbx exec cmd /c \"{}\"",
command.iter().map(|s| {
if s.contains(' ') { format!("\"{}\"", s) } else { s.clone() }
}).collect::<Vec<_>>().join(" ")
)
}
87 => "\n 💡 提示: 参数错误,检查命令格式".to_string(),
_ => String::new(),
};
return Err(anyhow!(
"CreateProcessAsUserW failed: {err} (cmd: {cmdline_str}){hint}"
));
}
if let Some(ph_out) = process_handle_out {
*ph_out.lock().unwrap() = Some(pi.hProcess);
}
let (tx_out, rx_out) = std::sync::mpsc::channel::<Vec<u8>>();
let (tx_err, rx_err) = std::sync::mpsc::channel::<Vec<u8>>();
let cancel_out_r = out_r;
let cancel_err_r = err_r;
let t_out = std::thread::spawn(move || {
let mut buf = Vec::new();
let mut tmp = [0u8; 8192];
loop {
let mut read_bytes: u32 = 0;
let ok = ReadFile(
out_r,
tmp.as_mut_ptr() as *mut _,
tmp.len() as u32,
&mut read_bytes,
ptr::null_mut(),
);
if ok == 0 || read_bytes == 0 {
break;
}
buf.extend_from_slice(&tmp[..read_bytes as usize]);
}
let _ = tx_out.send(buf);
});
let t_err = std::thread::spawn(move || {
let mut buf = Vec::new();
let mut tmp = [0u8; 8192];
loop {
let mut read_bytes: u32 = 0;
let ok = ReadFile(
err_r,
tmp.as_mut_ptr() as *mut _,
tmp.len() as u32,
&mut read_bytes,
ptr::null_mut(),
);
if ok == 0 || read_bytes == 0 {
break;
}
buf.extend_from_slice(&tmp[..read_bytes as usize]);
}
let _ = tx_err.send(buf);
});
let timed_out = if let Some(timeout) = timeout_ms {
let res = WaitForSingleObject(pi.hProcess, timeout as u32);
if res == 0x0000_0102 {
TerminateProcess(pi.hProcess, 1);
true
} else {
false
}
} else {
WaitForSingleObject(pi.hProcess, INFINITE);
false
};
let mut exit_code: u32 = 1;
if !timed_out {
GetExitCodeProcess(pi.hProcess, &mut exit_code);
}
if process_handle_out.is_none() {
CloseHandle(pi.hProcess);
}
CloseHandle(pi.hThread);
fn join_reader(
thread: std::thread::JoinHandle<()>,
cancel_handle: HANDLE,
label: &str,
) {
if thread.is_finished() {
let _ = thread.join();
return;
}
let start = std::time::Instant::now();
while start.elapsed() < std::time::Duration::from_secs(5) {
if thread.is_finished() {
let _ = thread.join();
return;
}
std::thread::sleep(std::time::Duration::from_millis(100));
}
eprintln!("⚠ 读取线程 '{label}' 超时(孙进程可能持有管道),强制关闭管道");
unsafe {
CloseHandle(cancel_handle);
}
let _ = thread.join();
}
join_reader(t_out, cancel_out_r, "stdout");
join_reader(t_err, cancel_err_r, "stderr");
let stdout = rx_out.recv().unwrap_or_default();
let stderr = rx_err.recv().unwrap_or_default();
Ok(ProcessResult {
exit_code,
stdout,
stderr,
timed_out,
})
}
pub unsafe fn execute_with_token_stream(
h_token: HANDLE,
command: &[String],
cwd: &Path,
env_map: &HashMap<String, String>,
timeout_ms: Option<u64>,
desktop_name: Option<&str>,
on_stdout: impl FnMut(&[u8]) + Send + 'static,
on_stderr: impl FnMut(&[u8]) + Send + 'static,
) -> Result<ProcessResult> {
if command.is_empty() {
return Err(anyhow!("Empty command"));
}
let ((in_r, in_w), (out_r, out_w), (err_r, err_w)) = create_pipes()?;
let cmdline_str = argv_to_command_line(command);
let mut cmdline = to_wide(&cmdline_str);
let env_block = if env_map.is_empty() {
let current_env: HashMap<String, String> = std::env::vars().collect();
make_env_block(¤t_env)
} else {
make_env_block(env_map)
};
let cwd_wide = to_wide(cwd);
let desktop_str = desktop_name.unwrap_or("Winsta0\\Default");
let desktop_name_wide = to_wide(desktop_str);
let mut si: STARTUPINFOW = std::mem::zeroed();
si.cb = std::mem::size_of::<STARTUPINFOW>() as u32;
si.lpDesktop = desktop_name_wide.as_ptr() as *mut u16;
si.dwFlags = STARTF_USESTDHANDLES;
si.hStdInput = in_r;
si.hStdOutput = out_w;
si.hStdError = err_w;
for &handle in &[in_r, out_w, err_w] {
if SetHandleInformation(handle, HANDLE_FLAG_INHERIT, HANDLE_FLAG_INHERIT) == 0 {
CloseHandle(in_r); CloseHandle(in_w);
CloseHandle(out_r); CloseHandle(out_w);
CloseHandle(err_r); CloseHandle(err_w);
return Err(anyhow!("SetHandleInformation failed: {}", GetLastError()));
}
}
let mut pi: PROCESS_INFORMATION = std::mem::zeroed();
let ok = CreateProcessAsUserW(
h_token,
ptr::null(),
cmdline.as_mut_ptr(),
ptr::null_mut(),
ptr::null_mut(),
1,
CREATE_UNICODE_ENVIRONMENT,
env_block.as_ptr() as *mut _,
cwd_wide.as_ptr(),
&si,
&mut pi,
);
CloseHandle(in_r);
CloseHandle(in_w);
CloseHandle(out_w);
CloseHandle(err_w);
if ok == 0 {
let err = GetLastError();
CloseHandle(out_r);
CloseHandle(err_r);
let hint = match err {
2 => {
let first_cmd = command.first().map(|s| s.as_str()).unwrap_or("");
format!(
"\n 💡 提示: '{first_cmd}' 不是可执行文件(cmd 内部命令需要用 cmd /c 前缀)\n \
正确用法: wsbx exec cmd /c \"{}\"",
command.iter().map(|s| {
if s.contains(' ') { format!("\"{}\"", s) } else { s.clone() }
}).collect::<Vec<_>>().join(" ")
)
}
_ => String::new(),
};
return Err(anyhow!(
"CreateProcessAsUserW failed: {err} (cmd: {cmdline_str}){hint}"
));
}
let (tx_out, rx_out) = std::sync::mpsc::channel::<(Vec<u8>, Vec<u8>)>();
let (tx_done, rx_done) = std::sync::mpsc::channel::<()>();
let mut on_stdout = on_stdout;
let mut on_stderr = on_stderr;
let cancel_out_r = out_r;
let cancel_err_r = err_r;
let t_reader = std::thread::spawn(move || {
let mut buf_out = Vec::new();
let mut buf_err = Vec::new();
let mut tmp = [0u8; 8192];
loop {
let mut any_read = false;
let mut read_bytes: u32 = 0;
let ok = ReadFile(
out_r,
tmp.as_mut_ptr() as *mut _,
tmp.len() as u32,
&mut read_bytes,
ptr::null_mut(),
);
if ok != 0 && read_bytes > 0 {
let chunk = &tmp[..read_bytes as usize];
buf_out.extend_from_slice(chunk);
on_stdout(chunk);
any_read = true;
}
let mut read_bytes: u32 = 0;
let ok = ReadFile(
err_r,
tmp.as_mut_ptr() as *mut _,
tmp.len() as u32,
&mut read_bytes,
ptr::null_mut(),
);
if ok != 0 && read_bytes > 0 {
let chunk = &tmp[..read_bytes as usize];
buf_err.extend_from_slice(chunk);
on_stderr(chunk);
any_read = true;
}
if !any_read {
let mut eof_out = false;
let mut read_bytes: u32 = 0;
if ReadFile(
out_r,
tmp.as_mut_ptr() as *mut _,
tmp.len() as u32,
&mut read_bytes,
ptr::null_mut(),
) == 0 || read_bytes == 0 {
eof_out = true;
} else if read_bytes > 0 {
let chunk = &tmp[..read_bytes as usize];
buf_out.extend_from_slice(chunk);
on_stdout(chunk);
continue;
}
let mut eof_err = false;
let mut read_bytes: u32 = 0;
if ReadFile(
err_r,
tmp.as_mut_ptr() as *mut _,
tmp.len() as u32,
&mut read_bytes,
ptr::null_mut(),
) == 0 || read_bytes == 0 {
eof_err = true;
} else if read_bytes > 0 {
let chunk = &tmp[..read_bytes as usize];
buf_err.extend_from_slice(chunk);
on_stderr(chunk);
continue;
}
if eof_out && eof_err {
break;
}
}
}
let _ = tx_out.send((buf_out, buf_err));
let _ = tx_done.send(());
});
let timed_out = if let Some(timeout) = timeout_ms {
let res = WaitForSingleObject(pi.hProcess, timeout as u32);
if res == 0x0000_0102 {
TerminateProcess(pi.hProcess, 1);
true
} else {
false
}
} else {
WaitForSingleObject(pi.hProcess, INFINITE);
false
};
let mut exit_code: u32 = 1;
if !timed_out {
GetExitCodeProcess(pi.hProcess, &mut exit_code);
}
CloseHandle(pi.hProcess);
CloseHandle(pi.hThread);
match rx_done.recv_timeout(std::time::Duration::from_secs(5)) {
Ok(_) => {
let _ = t_reader.join();
}
Err(_) => {
eprintln!("⚠ 读取线程超时(孙进程可能持有管道句柄),强制关闭管道");
unsafe {
CloseHandle(cancel_out_r);
CloseHandle(cancel_err_r);
}
let _ = t_reader.join();
}
}
let (stdout, stderr) = rx_out.recv().unwrap_or_default();
Ok(ProcessResult {
exit_code,
stdout,
stderr,
timed_out,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned argv from string slices.
    fn argv(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|p| p.to_string()).collect()
    }

    /// Builds an environment map from `(key, value)` pairs.
    fn env_of(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn test_argv_simple_no_quoting() {
        assert_eq!(argv_to_command_line(&argv(&["hello", "world"])), "hello world");
    }

    #[test]
    fn test_argv_with_space_adds_quotes() {
        assert_eq!(
            argv_to_command_line(&argv(&["hello world", "test"])),
            "\"hello world\" test"
        );
    }

    #[test]
    fn test_argv_with_embedded_quote() {
        assert_eq!(
            argv_to_command_line(&argv(&["it's\"done\""])),
            "\"it's\\\"done\\\"\""
        );
    }

    #[test]
    fn test_argv_with_trailing_backslash() {
        assert_eq!(
            argv_to_command_line(&argv(&["path\\", "end"])),
            "\"path\\\\\" end"
        );
    }

    #[test]
    fn test_argv_with_trailing_backslashes_and_quote() {
        // Two backslashes before a quote become five: 2 * 2 + 1.
        let expected = format!("\"a{}\"b\"", "\\".repeat(5));
        assert_eq!(argv_to_command_line(&argv(&["a\\\\\"b"])), expected);
    }

    #[test]
    fn test_argv_mixed_complex() {
        let line = argv_to_command_line(&argv(&[
            "simple",
            "with space",
            "trail\\",
            "quote\"here",
        ]));
        assert_eq!(line, "simple \"with space\" \"trail\\\\\" \"quote\\\"here\"");
    }

    #[test]
    fn test_argv_trailing_backslashes_and_quote_combined() {
        // One backslash before a quote becomes three: 1 * 2 + 1.
        let expected = format!("\"a{}\"b\"", "\\".repeat(3));
        assert_eq!(argv_to_command_line(&argv(&["a\\\"b"])), expected);
    }

    #[test]
    fn test_argv_empty_arg() {
        assert_eq!(argv_to_command_line(&argv(&[""])), "");
    }

    #[test]
    fn test_argv_only_backslashes() {
        assert_eq!(argv_to_command_line(&argv(&["\\\\"])), "\"\\\\\\\\\"");
    }

    #[test]
    fn test_argv_multiple_args_roundtrip() {
        assert_eq!(
            argv_to_command_line(&argv(&["hello", "foo bar", "end"])),
            "hello \"foo bar\" end"
        );
    }

    #[test]
    fn test_env_block_single_entry() {
        let block = make_env_block(&env_of(&[("PATH", "C:\\bin")]));
        assert!(block.len() >= 2, "至少应有 2 个 null: {:?}", block);
        let n = block.len();
        assert_eq!(block[n - 1], 0, "应以双 null 结尾");
        assert_eq!(block[n - 2], 0, "倒数第二个也应为 null");
    }

    #[test]
    fn test_env_block_double_null_terminated() {
        let block = make_env_block(&env_of(&[("A", "1")]));
        assert!(block.ends_with(&[0, 0]), "环境块应以双 null 终止");
    }

    #[test]
    fn test_env_block_multiple_entries() {
        let block = make_env_block(&env_of(&[("B", "2"), ("A", "1")]));
        assert!(block.ends_with(&[0, 0]));
    }

    #[test]
    fn test_env_block_empty_env_produces_only_double_null() {
        let block = make_env_block(&env_of(&[]));
        assert_eq!(block, vec![0, 0], "空环境块应只有双 null");
    }
}