use std::io::{Read, Write};
use std::process::{Child, ChildStdin, Command, Stdio};
use std::sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
mpsc,
};
use std::thread;
use std::time::{Duration, Instant};
use mlua::{Lua, MultiValue, Table, UserData, UserDataMethods, Value as LuaValue};
use crate::runtime::encoding::{RuntimeTextEncoding, decode_runtime_text, encode_runtime_text};
#[cfg(unix)]
use libc::{ESRCH, SIGKILL};
#[cfg(windows)]
use std::mem::size_of;
#[cfg(unix)]
use std::os::unix::process::CommandExt;
#[cfg(windows)]
use std::os::windows::io::{AsRawHandle, FromRawHandle, OwnedHandle};
#[cfg(windows)]
use std::os::windows::process::CommandExt;
#[cfg(windows)]
use windows_sys::Win32::Foundation::{
ERROR_ACCESS_DENIED, ERROR_INVALID_PARAMETER, HANDLE, INVALID_HANDLE_VALUE, WAIT_FAILED,
WAIT_OBJECT_0, WAIT_TIMEOUT,
};
#[cfg(windows)]
use windows_sys::Win32::System::Diagnostics::ToolHelp::{
CreateToolhelp32Snapshot, PROCESSENTRY32W, Process32FirstW, Process32NextW, TH32CS_SNAPPROCESS,
};
#[cfg(windows)]
use windows_sys::Win32::System::JobObjects::{
AssignProcessToJobObject, CreateJobObjectW, IsProcessInJob, JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
JOBOBJECT_EXTENDED_LIMIT_INFORMATION, JobObjectExtendedLimitInformation,
SetInformationJobObject, TerminateJobObject,
};
#[cfg(windows)]
use windows_sys::Win32::System::Threading::{
CREATE_BREAKAWAY_FROM_JOB, CREATE_NEW_PROCESS_GROUP, GetCurrentProcess, GetExitCodeProcess,
GetProcessTimes, OpenProcess, PROCESS_QUERY_LIMITED_INFORMATION, PROCESS_TERMINATE,
TerminateProcess, WaitForSingleObject,
};
const DEFAULT_SESSION_READ_TIMEOUT_MS: u64 = 100;
const DEFAULT_SESSION_CLOSE_TIMEOUT_MS: u64 = 1_000;
const DEFAULT_SESSION_MAX_READ_BYTES: usize = 64 * 1024;
const DEFAULT_SESSION_BUFFER_LIMIT_BYTES: usize = 1024 * 1024;
#[cfg(windows)]
const WINDOWS_SYNCHRONIZE_ACCESS: u32 = 0x0010_0000;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct ProcessStatusSnapshot {
running: bool,
exited: bool,
success: Option<bool>,
code: Option<i32>,
}
struct ProcessSessionOpenRequest {
program: String,
args: Vec<String>,
cwd: Option<String>,
stdout_encoding: RuntimeTextEncoding,
stderr_encoding: RuntimeTextEncoding,
stdin_encoding: RuntimeTextEncoding,
buffer_limit_bytes: usize,
}
struct ProcessSessionReadRequest {
timeout_ms: u64,
max_bytes: usize,
until_text: Option<String>,
}
struct ProcessSessionCloseRequest {
timeout_ms: u64,
}
struct ManagedProcessSessionState {
child: Mutex<Child>,
process_tree: ProcessTreeController,
stdin: Mutex<Option<ChildStdin>>,
stdout_buffer: Arc<Mutex<Vec<u8>>>,
stderr_buffer: Arc<Mutex<Vec<u8>>>,
stdout_encoding: RuntimeTextEncoding,
stderr_encoding: RuntimeTextEncoding,
stdin_encoding: RuntimeTextEncoding,
stdout_reader: Mutex<Option<SessionPipeReader>>,
stderr_reader: Mutex<Option<SessionPipeReader>>,
closed: Mutex<bool>,
final_status: Mutex<Option<ProcessStatusSnapshot>>,
}
struct SessionPipeReader {
handle: thread::JoinHandle<()>,
done_rx: mpsc::Receiver<()>,
done: Arc<AtomicBool>,
}
struct ProcessTreeController {
#[cfg(windows)]
strategy: WindowsProcessTreeStrategy,
}
#[cfg(windows)]
struct WindowsProcessJob {
handle: OwnedHandle,
}
#[cfg(windows)]
enum WindowsProcessTreeStrategy {
Job(WindowsProcessJob),
Snapshot,
}
#[cfg(windows)]
struct WindowsSnapshotProcessEntry {
pid: u32,
handle: Option<OwnedHandle>,
}
#[derive(Clone)]
struct ManagedProcessSession {
state: Arc<ManagedProcessSessionState>,
}
impl ManagedProcessSession {
fn open(request: ProcessSessionOpenRequest) -> mlua::Result<Self> {
let mut command = Command::new(&request.program);
command.args(&request.args);
let breakaway_requested = ProcessTreeController::prepare_command(&mut command)
.map_err(|error| mlua::Error::runtime(format!("process.session.open: {error}")))?;
if let Some(cwd) = request.cwd.as_deref() {
command.current_dir(cwd);
}
command
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
let mut child =
ProcessTreeController::spawn_prepared_command(&mut command, breakaway_requested)
.map_err(|error| mlua::Error::runtime(format!("process.session.open: {error}")))?;
let process_tree = ProcessTreeController::attach(&child)
.map_err(|error| mlua::Error::runtime(format!("process.session.open: {error}")))?;
let stdin = child.stdin.take();
let stdout_buffer = Arc::new(Mutex::new(Vec::new()));
let stderr_buffer = Arc::new(Mutex::new(Vec::new()));
let stdout_reader = child.stdout.take().map(|stdout| {
spawn_session_pipe_reader(
stdout,
stdout_buffer.clone(),
request.buffer_limit_bytes,
"stdout",
)
});
let stderr_reader = child.stderr.take().map(|stderr| {
spawn_session_pipe_reader(
stderr,
stderr_buffer.clone(),
request.buffer_limit_bytes,
"stderr",
)
});
Ok(Self {
state: Arc::new(ManagedProcessSessionState {
child: Mutex::new(child),
process_tree,
stdin: Mutex::new(stdin),
stdout_buffer,
stderr_buffer,
stdout_encoding: request.stdout_encoding,
stderr_encoding: request.stderr_encoding,
stdin_encoding: request.stdin_encoding,
stdout_reader: Mutex::new(stdout_reader),
stderr_reader: Mutex::new(stderr_reader),
closed: Mutex::new(false),
final_status: Mutex::new(None),
}),
})
}
fn write_values(&self, values: MultiValue) -> mlua::Result<bool> {
let mut stdin = self
.state
.stdin
.lock()
.map_err(|_| mlua::Error::runtime("process.session.write: stdin lock poisoned"))?;
let stdin = stdin
.as_mut()
.ok_or_else(|| mlua::Error::runtime("process.session.write: stdin is closed"))?;
for value in values {
let text = lua_value_to_session_text(value, "process.session.write")?;
let bytes = encode_runtime_text(&text, self.state.stdin_encoding)
.map_err(|error| mlua::Error::runtime(format!("process.session.write: {error}")))?;
stdin
.write_all(&bytes)
.map_err(|error| mlua::Error::runtime(format!("process.session.write: {error}")))?;
}
stdin
.flush()
.map_err(|error| mlua::Error::runtime(format!("process.session.write: {error}")))?;
Ok(true)
}
fn read(&self, lua: &Lua, args: MultiValue) -> mlua::Result<Table> {
let request = parse_session_read_request(args)?;
let deadline = Instant::now() + Duration::from_millis(request.timeout_ms);
let mut timed_out = false;
loop {
if self.has_readable_output(&request.until_text)? || self.output_streams_drained()? {
break;
}
if Instant::now() >= deadline {
timed_out = true;
break;
}
thread::sleep(Duration::from_millis(10));
}
let stdout_bytes = drain_buffer(&self.state.stdout_buffer, request.max_bytes)?;
let stderr_bytes = drain_buffer(&self.state.stderr_buffer, request.max_bytes)?;
let stdout = decode_runtime_text(&stdout_bytes, self.state.stdout_encoding);
let stderr = decode_runtime_text(&stderr_bytes, self.state.stderr_encoding);
let result = lua.create_table()?;
result.set("stdout", stdout.text)?;
result.set("stderr", stderr.text)?;
result.set("stdout_encoding", stdout.encoding)?;
result.set("stderr_encoding", stderr.encoding)?;
result.set("stdout_lossy", stdout.lossy)?;
result.set("stderr_lossy", stderr.lossy)?;
result.set("stdout_base64", stdout.base64)?;
result.set("stderr_base64", stderr.base64)?;
result.set("timed_out", timed_out)?;
Ok(result)
}
fn status(&self, lua: &Lua) -> mlua::Result<Table> {
let status = self
.state
.peek_status_snapshot()
.map_err(|error| mlua::Error::runtime(format!("process.session.status: {error}")))?;
process_status_snapshot_to_lua_table(lua, &status)
}
fn close(&self, lua: &Lua, args: MultiValue) -> mlua::Result<Table> {
let request = parse_session_close_request(args)?;
self.close_stdin("process.session.close")?;
let deadline = Instant::now() + Duration::from_millis(request.timeout_ms);
loop {
if self
.state
.peek_status_snapshot()
.map_err(|error| mlua::Error::runtime(format!("process.session.close: {error}")))?
.exited
{
break;
}
if Instant::now() >= deadline {
break;
}
thread::sleep(Duration::from_millis(10));
}
self.mark_closed("process.session.close")?;
let final_status = self.kill_child()?;
self.join_reader_threads("process.session.close")?;
process_status_snapshot_to_lua_table(lua, &final_status)
}
fn kill(&self) -> mlua::Result<bool> {
self.mark_closed("process.session.kill")?;
let _ = self.kill_child()?;
self.join_reader_threads("process.session.kill")?;
Ok(true)
}
fn has_readable_output(&self, until_text: &Option<String>) -> mlua::Result<bool> {
let stdout = self
.state
.stdout_buffer
.lock()
.map_err(|_| mlua::Error::runtime("process.session.read: stdout lock poisoned"))?;
let stderr = self
.state
.stderr_buffer
.lock()
.map_err(|_| mlua::Error::runtime("process.session.read: stderr lock poisoned"))?;
if stdout.is_empty() && stderr.is_empty() {
return Ok(false);
}
if let Some(marker) = until_text {
let stdout_text = decode_runtime_text(&stdout, self.state.stdout_encoding).text;
let stderr_text = decode_runtime_text(&stderr, self.state.stderr_encoding).text;
return Ok(stdout_text.contains(marker) || stderr_text.contains(marker));
}
Ok(true)
}
fn output_streams_drained(&self) -> mlua::Result<bool> {
self.state
.output_readers_drained()
.map_err(|error| mlua::Error::runtime(format!("process.session.read: {error}")))
}
fn kill_child(&self) -> mlua::Result<ProcessStatusSnapshot> {
self.state
.kill_process_tree_and_wait()
.map_err(|error| mlua::Error::runtime(format!("process.session.kill: {error}")))
}
fn close_stdin(&self, operation_name: &str) -> mlua::Result<()> {
self.state
.close_stdin_pipe()
.map_err(|error| mlua::Error::runtime(format!("{operation_name}: {error}")))
}
fn mark_closed(&self, operation_name: &str) -> mlua::Result<()> {
self.state
.mark_closed()
.map_err(|error| mlua::Error::runtime(format!("{operation_name}: {error}")))
}
fn join_reader_threads(&self, operation_name: &str) -> mlua::Result<()> {
self.state
.join_reader_threads()
.map_err(|error| mlua::Error::runtime(format!("{operation_name}: {error}")))
}
}
impl ManagedProcessSessionState {
fn cached_final_status(&self) -> Result<Option<ProcessStatusSnapshot>, String> {
self.final_status
.lock()
.map(|guard| *guard)
.map_err(|_| "final_status lock poisoned".to_string())
}
fn store_final_status(&self, status: ProcessStatusSnapshot) -> Result<(), String> {
let mut final_status = self
.final_status
.lock()
.map_err(|_| "final_status lock poisoned".to_string())?;
*final_status = Some(status);
Ok(())
}
fn reader_completed(
handle: &Mutex<Option<SessionPipeReader>>,
stream_name: &'static str,
) -> Result<bool, String> {
let reader_slot = handle
.lock()
.map_err(|_| format!("{stream_name} reader lock poisoned"))?;
Ok(reader_slot
.as_ref()
.map(|reader| reader.done.load(Ordering::Acquire))
.unwrap_or(true))
}
fn output_readers_drained(&self) -> Result<bool, String> {
Ok(
Self::reader_completed(&self.stdout_reader, "stdout")?
&& Self::reader_completed(&self.stderr_reader, "stderr")?,
)
}
fn close_stdin_pipe(&self) -> Result<(), String> {
let mut stdin = self
.stdin
.lock()
.map_err(|_| "stdin lock poisoned".to_string())?;
stdin.take();
Ok(())
}
fn mark_closed(&self) -> Result<(), String> {
let mut closed = self
.closed
.lock()
.map_err(|_| "closed lock poisoned".to_string())?;
*closed = true;
Ok(())
}
fn peek_status_snapshot(&self) -> Result<ProcessStatusSnapshot, String> {
if let Some(status) = self.cached_final_status()? {
return Ok(status);
}
#[cfg(unix)]
{
let child = self
.child
.lock()
.map_err(|_| "child lock poisoned".to_string())?;
let mut info = std::mem::MaybeUninit::<libc::siginfo_t>::zeroed();
let result = unsafe {
libc::waitid(
libc::P_PID,
child.id() as libc::id_t,
info.as_mut_ptr(),
libc::WEXITED | libc::WNOHANG | libc::WNOWAIT,
)
};
if result != 0 {
let error = std::io::Error::last_os_error();
if error.raw_os_error() == Some(libc::ECHILD) {
return Ok(ProcessStatusSnapshot {
running: false,
exited: true,
success: None,
code: None,
});
}
return Err(format!("waitid: {error}"));
}
let info = unsafe { info.assume_init() };
let reported_pid = unsafe { info.si_pid() };
if reported_pid == 0 {
return Ok(ProcessStatusSnapshot {
running: true,
exited: false,
success: None,
code: None,
});
}
let status_code = unsafe { info.si_status() };
let signal_code = info.si_code;
let (success, code) = if signal_code == libc::CLD_EXITED {
(Some(status_code == 0), Some(status_code))
} else {
(Some(false), None)
};
return Ok(ProcessStatusSnapshot {
running: false,
exited: true,
success,
code,
});
}
#[cfg(windows)]
{
let child = self
.child
.lock()
.map_err(|_| "child lock poisoned".to_string())?;
return peek_windows_process_status(child.as_raw_handle() as HANDLE);
}
#[cfg(all(not(unix), not(windows)))]
{
let mut child = self
.child
.lock()
.map_err(|_| "child lock poisoned".to_string())?;
match child.try_wait().map_err(|error| error.to_string())? {
Some(status) => Ok(ProcessStatusSnapshot {
running: false,
exited: true,
success: Some(status.success()),
code: status.code(),
}),
None => Ok(ProcessStatusSnapshot {
running: true,
exited: false,
success: None,
code: None,
}),
}
}
}
fn kill_process_tree_and_wait(&self) -> Result<ProcessStatusSnapshot, String> {
if let Some(status) = self.cached_final_status()? {
return Ok(status);
}
let mut child = self
.child
.lock()
.map_err(|_| "child lock poisoned".to_string())?;
self.process_tree.terminate(&child)?;
let status = match child.try_wait().map_err(|error| error.to_string())? {
Some(status) => Some(status),
None => child.wait().map(Some).map_err(|error| error.to_string())?,
};
let snapshot = process_status_snapshot_from_exit_status(status);
self.store_final_status(snapshot)?;
Ok(snapshot)
}
fn join_one_reader(
handle: &Mutex<Option<SessionPipeReader>>,
stream_name: &'static str,
) -> Result<(), String> {
let should_take = {
let mut reader_slot = handle
.lock()
.map_err(|_| format!("{stream_name} reader lock poisoned"))?;
let Some(reader) = reader_slot.as_mut() else {
return Ok(());
};
match reader
.done_rx
.recv_timeout(Duration::from_millis(DEFAULT_SESSION_CLOSE_TIMEOUT_MS))
{
Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true,
Err(mpsc::RecvTimeoutError::Timeout) => {
return Err(format!(
"{stream_name} reader shutdown timed out after {DEFAULT_SESSION_CLOSE_TIMEOUT_MS}ms"
));
}
}
};
if should_take {
let reader = handle
.lock()
.map_err(|_| format!("{stream_name} reader lock poisoned"))?
.take();
if let Some(reader) = reader {
reader
.handle
.join()
.map_err(|_| format!("{stream_name} reader panicked"))?;
}
}
Ok(())
}
fn join_reader_threads(&self) -> Result<(), String> {
Self::join_one_reader(&self.stdout_reader, "stdout")?;
Self::join_one_reader(&self.stderr_reader, "stderr")?;
Ok(())
}
fn cleanup_on_drop(&self) {
let _ = self.mark_closed();
let _ = self.close_stdin_pipe();
if self.cached_final_status().ok().flatten().is_none() {
let _ = self.kill_process_tree_and_wait();
}
let _ = self.join_reader_threads();
}
}
impl Drop for ManagedProcessSessionState {
fn drop(&mut self) {
self.cleanup_on_drop();
}
}
impl UserData for ManagedProcessSession {
fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
methods.add_method("write", |_, session, values: MultiValue| {
session.write_values(values)
});
methods.add_method("read", |lua, session, args: MultiValue| {
session.read(lua, args)
});
methods.add_method("status", |lua, session, ()| session.status(lua));
methods.add_method("close", |lua, session, args: MultiValue| {
session.close(lua, args)
});
methods.add_method("kill", |_, session, ()| session.kill());
}
}
pub(crate) fn create_process_session_table(
lua: &Lua,
default_encoding: RuntimeTextEncoding,
) -> mlua::Result<Table> {
let table = lua.create_table()?;
table.set(
"open",
lua.create_function(move |lua, spec: LuaValue| {
let request = parse_session_open_request(spec, default_encoding)?;
let session = ManagedProcessSession::open(request)?;
lua.create_userdata(session)
})?,
)?;
Ok(table)
}
fn parse_session_open_request(
value: LuaValue,
default_encoding: RuntimeTextEncoding,
) -> mlua::Result<ProcessSessionOpenRequest> {
let table = match value {
LuaValue::Table(table) => table,
other => {
return Err(mlua::Error::runtime(format!(
"process.session.open: spec must be a table, got {}",
lua_value_type_name(&other)
)));
}
};
let program = require_string_field(&table, "program", "process.session.open")?;
let args = parse_string_array_field(&table, "args", "process.session.open")?;
let cwd = parse_optional_string_field(&table, "cwd", "process.session.open")?;
let encoding = parse_optional_encoding_field(&table, "encoding", "process.session.open")?
.unwrap_or(default_encoding);
let stdout_encoding =
parse_optional_encoding_field(&table, "stdout_encoding", "process.session.open")?
.unwrap_or(encoding);
let stderr_encoding =
parse_optional_encoding_field(&table, "stderr_encoding", "process.session.open")?
.unwrap_or(encoding);
let stdin_encoding =
parse_optional_encoding_field(&table, "stdin_encoding", "process.session.open")?
.unwrap_or(encoding);
let buffer_limit_bytes = parse_optional_usize_field(
&table,
"buffer_limit_bytes",
"process.session.open",
DEFAULT_SESSION_BUFFER_LIMIT_BYTES,
)?;
Ok(ProcessSessionOpenRequest {
program,
args,
cwd,
stdout_encoding,
stderr_encoding,
stdin_encoding,
buffer_limit_bytes,
})
}
fn parse_session_read_request(args: MultiValue) -> mlua::Result<ProcessSessionReadRequest> {
let mut values = args.into_iter();
let value = values.next().unwrap_or(LuaValue::Nil);
match value {
LuaValue::Nil => Ok(ProcessSessionReadRequest {
timeout_ms: DEFAULT_SESSION_READ_TIMEOUT_MS,
max_bytes: DEFAULT_SESSION_MAX_READ_BYTES,
until_text: None,
}),
LuaValue::Table(table) => Ok(ProcessSessionReadRequest {
timeout_ms: parse_optional_u64_field(
&table,
"timeout_ms",
"process.session.read",
DEFAULT_SESSION_READ_TIMEOUT_MS,
)?,
max_bytes: parse_optional_usize_field(
&table,
"max_bytes",
"process.session.read",
DEFAULT_SESSION_MAX_READ_BYTES,
)?,
until_text: parse_optional_string_field(&table, "until_text", "process.session.read")?,
}),
other => Err(mlua::Error::runtime(format!(
"process.session.read: options must be a table, got {}",
lua_value_type_name(&other)
))),
}
}
fn parse_session_close_request(args: MultiValue) -> mlua::Result<ProcessSessionCloseRequest> {
let mut values = args.into_iter();
let value = values.next().unwrap_or(LuaValue::Nil);
match value {
LuaValue::Nil => Ok(ProcessSessionCloseRequest {
timeout_ms: DEFAULT_SESSION_CLOSE_TIMEOUT_MS,
}),
LuaValue::Table(table) => Ok(ProcessSessionCloseRequest {
timeout_ms: parse_optional_u64_field(
&table,
"timeout_ms",
"process.session.close",
DEFAULT_SESSION_CLOSE_TIMEOUT_MS,
)?,
}),
other => Err(mlua::Error::runtime(format!(
"process.session.close: options must be a table, got {}",
lua_value_type_name(&other)
))),
}
}
fn spawn_session_pipe_reader<R>(
mut reader: R,
target: Arc<Mutex<Vec<u8>>>,
limit_bytes: usize,
stream_name: &'static str,
) -> SessionPipeReader
where
R: Read + Send + 'static,
{
let (done_tx, done_rx) = mpsc::channel();
let done = Arc::new(AtomicBool::new(false));
let done_flag = done.clone();
let handle = thread::spawn(move || {
let mut chunk = [0_u8; 4096];
loop {
match reader.read(&mut chunk) {
Ok(0) => break,
Ok(count) => {
if let Ok(mut buffer) = target.lock() {
append_bounded(&mut buffer, &chunk[..count], limit_bytes);
} else {
break;
}
}
Err(error) => {
crate::runtime_logging::warn(format!(
"[LuaSkill:warn] process.session {stream_name} reader failed: {error}"
));
break;
}
}
}
done_flag.store(true, Ordering::Release);
let _ = done_tx.send(());
});
SessionPipeReader {
handle,
done_rx,
done,
}
}
impl ProcessTreeController {
fn prepare_command(command: &mut Command) -> Result<bool, String> {
#[cfg(unix)]
{
command.process_group(0);
return Ok(false);
}
#[cfg(windows)]
{
let in_job = current_process_is_in_job()?;
let creation_flags = if in_job {
CREATE_NEW_PROCESS_GROUP | CREATE_BREAKAWAY_FROM_JOB
} else {
CREATE_NEW_PROCESS_GROUP
};
command.creation_flags(creation_flags);
return Ok(in_job);
}
#[cfg(not(any(unix, windows)))]
{
let _ = command;
Ok(false)
}
}
fn spawn_prepared_command(
command: &mut Command,
breakaway_requested: bool,
) -> Result<Child, std::io::Error> {
#[cfg(windows)]
{
match command.spawn() {
Ok(child) => Ok(child),
Err(error)
if breakaway_requested
&& error.raw_os_error() == Some(ERROR_ACCESS_DENIED as i32) =>
{
crate::runtime_logging::warn(
"[LuaSkill:warn] process.session falling back to inherited host job because CREATE_BREAKAWAY_FROM_JOB was denied"
.to_string(),
);
command.creation_flags(CREATE_NEW_PROCESS_GROUP);
command.spawn()
}
Err(error) => Err(error),
}
}
#[cfg(not(windows))]
{
let _ = breakaway_requested;
command.spawn()
}
}
fn attach(child: &Child) -> Result<Self, String> {
#[cfg(windows)]
{
let job = WindowsProcessJob::create()?;
match job.assign(child) {
Ok(()) => {
return Ok(Self {
strategy: WindowsProcessTreeStrategy::Job(job),
});
}
Err(WindowsJobAssignError::AccessDenied(message)) => {
crate::runtime_logging::warn(format!(
"[LuaSkill:warn] process.session is reusing ToolHelp process-tree fallback because Job Object assignment was denied: {message}"
));
return Ok(Self {
strategy: WindowsProcessTreeStrategy::Snapshot,
});
}
Err(WindowsJobAssignError::Other(message)) => return Err(message),
}
}
#[cfg(not(windows))]
{
let _ = child;
Ok(Self {})
}
}
fn terminate(&self, _child: &Child) -> Result<(), String> {
#[cfg(unix)]
{
let result = unsafe { libc::kill(-(_child.id() as i32), SIGKILL) };
if result == 0 {
return Ok(());
}
let error = std::io::Error::last_os_error();
if error.raw_os_error() == Some(ESRCH) {
return Ok(());
}
return Err(format!("kill process group: {error}"));
}
#[cfg(windows)]
{
return self.strategy.terminate(_child);
}
#[cfg(not(any(unix, windows)))]
{
let _ = _child;
Ok(())
}
}
}
#[cfg(windows)]
enum WindowsJobAssignError {
AccessDenied(String),
Other(String),
}
#[cfg(windows)]
impl WindowsProcessTreeStrategy {
fn terminate(&self, child: &Child) -> Result<(), String> {
match self {
Self::Job(job) => job.terminate(),
Self::Snapshot => terminate_windows_process_tree_snapshot(child),
}
}
}
#[cfg(windows)]
fn current_process_is_in_job() -> Result<bool, String> {
let mut in_job = 0;
let status = unsafe { IsProcessInJob(GetCurrentProcess(), std::ptr::null_mut(), &mut in_job) };
if status == 0 {
return Err(format!(
"IsProcessInJob: {}",
std::io::Error::last_os_error()
));
}
Ok(in_job != 0)
}
#[cfg(windows)]
fn terminate_windows_process_tree_snapshot(child: &Child) -> Result<(), String> {
let root_pid = child.id();
let descendants = collect_windows_descendant_processes(root_pid)?;
let mut first_error: Option<String> = None;
for descendant in descendants.into_iter().rev() {
if let Some(handle) = descendant.handle {
let label = format!("process {}", descendant.pid);
if let Err(error) = terminate_windows_process_handle(
handle.as_raw_handle() as HANDLE,
&label,
false,
) {
if first_error.is_none() {
first_error = Some(error);
}
}
}
}
if let Err(error) = terminate_windows_process_handle(
child.as_raw_handle() as HANDLE,
"process.session root process",
true,
) {
if first_error.is_none() {
first_error = Some(error);
}
}
if let Some(error) = first_error {
return Err(error);
}
Ok(())
}
#[cfg(windows)]
fn peek_windows_process_status(handle: HANDLE) -> Result<ProcessStatusSnapshot, String> {
let wait_status = unsafe { WaitForSingleObject(handle, 0) };
match wait_status {
WAIT_TIMEOUT => Ok(ProcessStatusSnapshot {
running: true,
exited: false,
success: None,
code: None,
}),
WAIT_OBJECT_0 => {
let mut exit_code = 0_u32;
let status = unsafe { GetExitCodeProcess(handle, &mut exit_code) };
if status == 0 {
return Err(format!(
"GetExitCodeProcess: {}",
std::io::Error::last_os_error()
));
}
let code = exit_code as i32;
Ok(ProcessStatusSnapshot {
running: false,
exited: true,
success: Some(code == 0),
code: Some(code),
})
}
WAIT_FAILED => Err(format!(
"WaitForSingleObject(process status): {}",
std::io::Error::last_os_error()
)),
other => Err(format!(
"WaitForSingleObject(process status) returned unexpected status {other}"
)),
}
}
#[cfg(windows)]
fn collect_windows_descendant_processes(
root_pid: u32,
) -> Result<Vec<WindowsSnapshotProcessEntry>, String> {
let snapshot = unsafe { CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) };
if snapshot == INVALID_HANDLE_VALUE {
return Err(format!(
"CreateToolhelp32Snapshot: {}",
std::io::Error::last_os_error()
));
}
let snapshot = unsafe { OwnedHandle::from_raw_handle(snapshot as _) };
let snapshot_captured_ticks = current_windows_time_ticks()?;
let mut entry: PROCESSENTRY32W = unsafe { std::mem::zeroed() };
entry.dwSize = size_of::<PROCESSENTRY32W>() as u32;
let mut children_by_parent =
std::collections::HashMap::<u32, Vec<WindowsSnapshotProcessEntry>>::new();
let mut has_entry =
unsafe { Process32FirstW(snapshot.as_raw_handle() as HANDLE, &mut entry) } != 0;
while has_entry {
let pid = entry.th32ProcessID;
children_by_parent
.entry(entry.th32ParentProcessID)
.or_default()
.push(WindowsSnapshotProcessEntry {
pid,
handle: try_open_windows_process_for_snapshot(pid, snapshot_captured_ticks)?,
});
has_entry = unsafe { Process32NextW(snapshot.as_raw_handle() as HANDLE, &mut entry) } != 0;
}
let mut ordered = Vec::new();
let mut stack = vec![root_pid];
while let Some(parent_pid) = stack.pop() {
if let Some(children) = children_by_parent.remove(&parent_pid) {
for child in children {
stack.push(child.pid);
ordered.push(child);
}
}
}
Ok(ordered)
}
#[cfg(windows)]
fn try_open_windows_process_for_snapshot(
pid: u32,
snapshot_started_ticks: u64,
) -> Result<Option<OwnedHandle>, String> {
let handle = unsafe {
OpenProcess(
PROCESS_QUERY_LIMITED_INFORMATION | PROCESS_TERMINATE | WINDOWS_SYNCHRONIZE_ACCESS,
0,
pid,
)
};
if handle.is_null() {
let error = std::io::Error::last_os_error();
match error.raw_os_error() {
Some(code)
if code == ERROR_ACCESS_DENIED as i32 || code == ERROR_INVALID_PARAMETER as i32 =>
{
return Ok(None);
}
_ => return Err(format!("OpenProcess({pid}): {error}")),
}
}
let handle = unsafe { OwnedHandle::from_raw_handle(handle as _) };
let created_ticks = get_windows_process_creation_ticks(handle.as_raw_handle() as HANDLE)?;
if created_ticks > snapshot_started_ticks {
return Ok(None);
}
Ok(Some(handle))
}
#[cfg(windows)]
fn current_windows_time_ticks() -> Result<u64, String> {
const WINDOWS_TO_UNIX_EPOCH_SECONDS: u64 = 11_644_473_600;
let unix_elapsed = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|error| format!("SystemTime before UNIX_EPOCH: {error}"))?;
Ok((unix_elapsed.as_secs() + WINDOWS_TO_UNIX_EPOCH_SECONDS) * 10_000_000
+ u64::from(unix_elapsed.subsec_nanos() / 100))
}
#[cfg(windows)]
fn get_windows_process_creation_ticks(handle: HANDLE) -> Result<u64, String> {
let mut creation_time = windows_sys::Win32::Foundation::FILETIME {
dwLowDateTime: 0,
dwHighDateTime: 0,
};
let mut exit_time = windows_sys::Win32::Foundation::FILETIME {
dwLowDateTime: 0,
dwHighDateTime: 0,
};
let mut kernel_time = windows_sys::Win32::Foundation::FILETIME {
dwLowDateTime: 0,
dwHighDateTime: 0,
};
let mut user_time = windows_sys::Win32::Foundation::FILETIME {
dwLowDateTime: 0,
dwHighDateTime: 0,
};
let status = unsafe {
GetProcessTimes(
handle,
&mut creation_time,
&mut exit_time,
&mut kernel_time,
&mut user_time,
)
};
if status == 0 {
return Err(format!(
"GetProcessTimes: {}",
std::io::Error::last_os_error()
));
}
Ok(((creation_time.dwHighDateTime as u64) << 32) | creation_time.dwLowDateTime as u64)
}
#[cfg(windows)]
fn terminate_windows_process_handle(
handle: HANDLE,
label: &str,
allow_timeout: bool,
) -> Result<(), String> {
let terminate_status = unsafe { TerminateProcess(handle, 1) };
if terminate_status == 0 {
let error = std::io::Error::last_os_error();
match error.raw_os_error() {
Some(code)
if code == ERROR_ACCESS_DENIED as i32 || code == ERROR_INVALID_PARAMETER as i32 => {
}
_ => return Err(format!("TerminateProcess({label}): {error}")),
}
}
let wait_status =
unsafe { WaitForSingleObject(handle, DEFAULT_SESSION_CLOSE_TIMEOUT_MS as u32) };
match wait_status {
WAIT_OBJECT_0 => Ok(()),
WAIT_TIMEOUT if allow_timeout => Ok(()),
WAIT_TIMEOUT => Err(format!(
"WaitForSingleObject({label}) timed out after {DEFAULT_SESSION_CLOSE_TIMEOUT_MS}ms"
)),
WAIT_FAILED => Err(format!(
"WaitForSingleObject({label}): {}",
std::io::Error::last_os_error()
)),
other => Err(format!(
"WaitForSingleObject({label}) returned unexpected status {other}"
)),
}
}
#[cfg(windows)]
impl WindowsProcessJob {
fn create() -> Result<Self, String> {
let raw = unsafe { CreateJobObjectW(std::ptr::null(), std::ptr::null()) };
if raw.is_null() {
return Err(format!(
"CreateJobObjectW: {}",
std::io::Error::last_os_error()
));
}
let handle = unsafe { OwnedHandle::from_raw_handle(raw as _) };
let mut info: JOBOBJECT_EXTENDED_LIMIT_INFORMATION = unsafe { std::mem::zeroed() };
info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
let status = unsafe {
SetInformationJobObject(
handle.as_raw_handle() as _,
JobObjectExtendedLimitInformation,
&info as *const _ as *const _,
size_of::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() as u32,
)
};
if status == 0 {
return Err(format!(
"SetInformationJobObject: {}",
std::io::Error::last_os_error()
));
}
Ok(Self { handle })
}
fn assign(&self, child: &Child) -> Result<(), WindowsJobAssignError> {
let status = unsafe {
AssignProcessToJobObject(self.handle.as_raw_handle() as _, child.as_raw_handle() as _)
};
if status == 0 {
let error = std::io::Error::last_os_error();
return match error.raw_os_error() {
Some(code) if code == ERROR_ACCESS_DENIED as i32 => {
Err(WindowsJobAssignError::AccessDenied(error.to_string()))
}
_ => Err(WindowsJobAssignError::Other(format!(
"AssignProcessToJobObject: {error}"
))),
};
}
Ok(())
}
fn terminate(&self) -> Result<(), String> {
let status = unsafe { TerminateJobObject(self.handle.as_raw_handle() as _, 1) };
if status == 0 {
let error = std::io::Error::last_os_error();
if error.raw_os_error() == Some(ERROR_ACCESS_DENIED as i32) {
return Ok(());
}
return Err(format!("TerminateJobObject: {error}"));
}
Ok(())
}
}
fn append_bounded(buffer: &mut Vec<u8>, bytes: &[u8], limit_bytes: usize) {
let limit_bytes = limit_bytes.max(1);
if bytes.len() >= limit_bytes {
buffer.clear();
buffer.extend_from_slice(&bytes[bytes.len() - limit_bytes..]);
return;
}
let total_len = buffer.len() + bytes.len();
if total_len > limit_bytes {
let overflow = total_len - limit_bytes;
buffer.drain(0..overflow.min(buffer.len()));
}
buffer.extend_from_slice(bytes);
}
fn drain_buffer(buffer: &Arc<Mutex<Vec<u8>>>, max_bytes: usize) -> mlua::Result<Vec<u8>> {
let mut guard = buffer
.lock()
.map_err(|_| mlua::Error::runtime("process.session.read: output lock poisoned"))?;
let count = max_bytes.min(guard.len());
Ok(guard.drain(0..count).collect())
}
fn process_status_snapshot_from_exit_status(
status: Option<std::process::ExitStatus>,
) -> ProcessStatusSnapshot {
match status {
Some(status) => ProcessStatusSnapshot {
running: false,
exited: true,
success: Some(status.success()),
code: status.code(),
},
None => ProcessStatusSnapshot {
running: true,
exited: false,
success: None,
code: None,
},
}
}
fn process_status_snapshot_to_lua_table(
lua: &Lua,
status: &ProcessStatusSnapshot,
) -> mlua::Result<Table> {
let table = lua.create_table()?;
table.set("running", status.running)?;
table.set("exited", status.exited)?;
match status.success {
Some(success) => table.set("success", success)?,
None => table.set("success", LuaValue::Nil)?,
}
match status.code {
Some(code) => table.set("code", code)?,
None => table.set("code", LuaValue::Nil)?,
}
Ok(table)
}
fn require_string_field(table: &Table, key: &str, fn_name: &str) -> mlua::Result<String> {
let value: LuaValue = table.get(key)?;
require_string_value(value, fn_name, key, false)
}
fn parse_optional_string_field(
table: &Table,
key: &str,
fn_name: &str,
) -> mlua::Result<Option<String>> {
let value: LuaValue = table.get(key)?;
match value {
LuaValue::Nil => Ok(None),
value => Ok(Some(require_string_value(value, fn_name, key, false)?)),
}
}
fn parse_optional_encoding_field(
table: &Table,
key: &str,
fn_name: &str,
) -> mlua::Result<Option<RuntimeTextEncoding>> {
let value: LuaValue = table.get(key)?;
match value {
LuaValue::Nil => Ok(None),
LuaValue::String(text) => {
let label = text
.to_str()
.map_err(|_| mlua::Error::runtime(format!("{fn_name}: {key} must be UTF-8")))?;
RuntimeTextEncoding::parse(label.as_ref())
.map(Some)
.map_err(|error| mlua::Error::runtime(format!("{fn_name}: {error}")))
}
other => Err(mlua::Error::runtime(format!(
"{fn_name}: {key} must be a string, got {}",
lua_value_type_name(&other)
))),
}
}
fn parse_string_array_field(table: &Table, key: &str, fn_name: &str) -> mlua::Result<Vec<String>> {
let value: LuaValue = table.get(key)?;
match value {
LuaValue::Nil => Ok(Vec::new()),
LuaValue::Table(items) => {
let mut output = Vec::new();
for pair in items.sequence_values::<LuaValue>() {
output.push(require_string_value(pair?, fn_name, key, true)?);
}
Ok(output)
}
other => Err(mlua::Error::runtime(format!(
"{fn_name}: {key} must be an array table, got {}",
lua_value_type_name(&other)
))),
}
}
fn parse_optional_u64_field(
table: &Table,
key: &str,
fn_name: &str,
default_value: u64,
) -> mlua::Result<u64> {
let value: LuaValue = table.get(key)?;
match value {
LuaValue::Nil => Ok(default_value),
LuaValue::Integer(number) if number > 0 => Ok(number as u64),
LuaValue::Number(number) if number.is_finite() && number > 0.0 => Ok(number as u64),
other => Err(mlua::Error::runtime(format!(
"{fn_name}: {key} must be a positive number, got {}",
lua_value_type_name(&other)
))),
}
}
fn parse_optional_usize_field(
table: &Table,
key: &str,
fn_name: &str,
default_value: usize,
) -> mlua::Result<usize> {
let value = parse_optional_u64_field(table, key, fn_name, default_value as u64)?;
usize::try_from(value).map_err(|_| {
mlua::Error::runtime(format!("{fn_name}: {key} is too large for this platform"))
})
}
fn lua_value_to_session_text(value: LuaValue, fn_name: &str) -> mlua::Result<String> {
match value {
LuaValue::String(text) => Ok(text
.to_str()
.map_err(|_| mlua::Error::runtime(format!("{fn_name}: string must be valid UTF-8")))?
.to_string()),
LuaValue::Integer(number) => Ok(number.to_string()),
LuaValue::Number(number) => Ok(number.to_string()),
LuaValue::Boolean(flag) => Ok(flag.to_string()),
other => Err(mlua::Error::runtime(format!(
"{fn_name}: unsupported value {}",
lua_value_type_name(&other)
))),
}
}
fn require_string_value(
value: LuaValue,
fn_name: &str,
param_name: &str,
allow_blank: bool,
) -> mlua::Result<String> {
let text = match value {
LuaValue::String(text) => text
.to_str()
.map_err(|_| {
mlua::Error::runtime(format!("{fn_name}: {param_name} must be valid UTF-8"))
})?
.to_string(),
other => {
return Err(mlua::Error::runtime(format!(
"{fn_name}: {param_name} must be a string, got {}",
lua_value_type_name(&other)
)));
}
};
if !allow_blank && text.trim().is_empty() {
return Err(mlua::Error::runtime(format!(
"{fn_name}: {param_name} must not be empty"
)));
}
if text.contains('\0') {
return Err(mlua::Error::runtime(format!(
"{fn_name}: {param_name} must not contain NUL bytes"
)));
}
Ok(text)
}
fn lua_value_type_name(value: &LuaValue) -> &'static str {
match value {
LuaValue::Nil => "nil",
LuaValue::Boolean(_) => "boolean",
LuaValue::LightUserData(_) => "lightuserdata",
LuaValue::Integer(_) | LuaValue::Number(_) => "number",
LuaValue::String(_) => "string",
LuaValue::Table(_) => "table",
LuaValue::Function(_) => "function",
LuaValue::Thread(_) => "thread",
LuaValue::UserData(_) => "userdata",
LuaValue::Error(_) => "error",
LuaValue::Other(_) => "other",
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::runtime::encoding::default_runtime_text_encoding;
use std::process::Command;
use std::thread;
use std::time::{Duration, Instant};
fn make_drop_cleanup_request() -> ProcessSessionOpenRequest {
let encoding = default_runtime_text_encoding();
if cfg!(windows) {
ProcessSessionOpenRequest {
program: "powershell".to_string(),
args: vec![
"-NoProfile".to_string(),
"-Command".to_string(),
"Start-Sleep -Seconds 30".to_string(),
],
cwd: None,
stdout_encoding: encoding,
stderr_encoding: encoding,
stdin_encoding: encoding,
buffer_limit_bytes: DEFAULT_SESSION_BUFFER_LIMIT_BYTES,
}
} else {
ProcessSessionOpenRequest {
program: "sleep".to_string(),
args: vec!["30".to_string()],
cwd: None,
stdout_encoding: encoding,
stderr_encoding: encoding,
stdin_encoding: encoding,
buffer_limit_bytes: DEFAULT_SESSION_BUFFER_LIMIT_BYTES,
}
}
}
fn make_descendant_cleanup_request() -> ProcessSessionOpenRequest {
let encoding = default_runtime_text_encoding();
if cfg!(windows) {
ProcessSessionOpenRequest {
program: "powershell".to_string(),
args: vec![
"-NoProfile".to_string(),
"-Command".to_string(),
"$child = Start-Process powershell -PassThru -WindowStyle Hidden -ArgumentList '-NoProfile','-Command','Start-Sleep -Seconds 30'; [Console]::Out.WriteLine($child.Id); [Console]::Out.Flush()".to_string(),
],
cwd: None,
stdout_encoding: encoding,
stderr_encoding: encoding,
stdin_encoding: encoding,
buffer_limit_bytes: DEFAULT_SESSION_BUFFER_LIMIT_BYTES,
}
} else {
ProcessSessionOpenRequest {
program: "sh".to_string(),
args: vec!["-c".to_string(), "sleep 30 & echo $!; exit 0".to_string()],
cwd: None,
stdout_encoding: encoding,
stderr_encoding: encoding,
stdin_encoding: encoding,
buffer_limit_bytes: DEFAULT_SESSION_BUFFER_LIMIT_BYTES,
}
}
}
fn make_immediate_exit_request() -> ProcessSessionOpenRequest {
let encoding = default_runtime_text_encoding();
if cfg!(windows) {
ProcessSessionOpenRequest {
program: "cmd".to_string(),
args: vec!["/c".to_string(), "exit 0".to_string()],
cwd: None,
stdout_encoding: encoding,
stderr_encoding: encoding,
stdin_encoding: encoding,
buffer_limit_bytes: DEFAULT_SESSION_BUFFER_LIMIT_BYTES,
}
} else {
ProcessSessionOpenRequest {
program: "sh".to_string(),
args: vec!["-c".to_string(), "exit 0".to_string()],
cwd: None,
stdout_encoding: encoding,
stderr_encoding: encoding,
stdin_encoding: encoding,
buffer_limit_bytes: DEFAULT_SESSION_BUFFER_LIMIT_BYTES,
}
}
}
fn process_exists(pid: u32) -> bool {
if cfg!(windows) {
Command::new("powershell")
.args([
"-NoProfile",
"-Command",
&format!(
"if (Get-Process -Id {} -ErrorAction SilentlyContinue) {{ exit 0 }} else {{ exit 1 }}",
pid
),
])
.status()
.map(|status| status.success())
.unwrap_or(false)
} else {
Command::new("sh")
.args(["-c", &format!("kill -0 {} 2>/dev/null", pid)])
.status()
.map(|status| status.success())
.unwrap_or(false)
}
}
fn assert_process_exits(pid: u32, timeout: Duration) {
let deadline = Instant::now() + timeout;
while Instant::now() < deadline {
if !process_exists(pid) {
return;
}
thread::sleep(Duration::from_millis(50));
}
panic!("process {pid} should have exited after session drop");
}
fn wait_for_descendant_pid(session: &ManagedProcessSession, timeout: Duration) -> u32 {
let deadline = Instant::now() + timeout;
while Instant::now() < deadline {
let stdout = session
.state
.stdout_buffer
.lock()
.expect("lock stdout buffer");
if !stdout.is_empty() {
let pid_text = stdout
.iter()
.filter_map(|byte| match byte {
b'0'..=b'9' => Some(char::from(*byte)),
b'\r' | b'\n' => Some('\n'),
_ => None,
})
.collect::<String>()
.lines()
.next()
.unwrap_or_default()
.trim()
.to_string();
drop(stdout);
if let Ok(pid) = pid_text.parse::<u32>() {
return pid;
}
}
thread::sleep(Duration::from_millis(25));
}
panic!("descendant pid should be published before cleanup");
}
#[test]
fn dropping_process_session_kills_child_process() {
let session = ManagedProcessSession::open(make_drop_cleanup_request())
.expect("open drop cleanup session");
let pid = session.state.child.lock().expect("lock child process").id();
assert!(
process_exists(pid),
"child process should be running before drop"
);
drop(session);
assert_process_exits(pid, Duration::from_secs(5));
}
#[test]
fn killing_process_session_terminates_descendants_and_releases_readers() {
let session = ManagedProcessSession::open(make_descendant_cleanup_request())
.expect("open descendant cleanup session");
let descendant_pid = wait_for_descendant_pid(&session, Duration::from_secs(15));
assert!(
process_exists(descendant_pid),
"descendant process should be running before cleanup"
);
session
.mark_closed("process.session.test")
.expect("mark process session closed");
let start = Instant::now();
session
.kill_child()
.expect("kill descendant process tree cleanly");
session
.join_reader_threads("process.session.test")
.expect("join process session readers");
assert!(
start.elapsed() < Duration::from_secs(5),
"process session cleanup should not block after tree termination"
);
assert_process_exits(descendant_pid, Duration::from_secs(5));
}
#[test]
fn process_session_tree_teardown_is_idempotent_after_explicit_kill() {
let session = ManagedProcessSession::open(make_drop_cleanup_request())
.expect("open idempotent session");
session
.mark_closed("process.session.test")
.expect("mark idempotent session closed");
let first = session
.kill_child()
.expect("first process tree teardown should succeed");
let second = session
.kill_child()
.expect("second process tree teardown should reuse cached final status");
assert_eq!(first, second);
}
#[test]
fn join_one_reader_timeout_preserves_reader_handle() {
let (release_tx, release_rx) = mpsc::channel();
let (done_tx, done_rx) = mpsc::channel();
let done = Arc::new(AtomicBool::new(false));
let done_flag = done.clone();
let handle = thread::spawn(move || {
release_rx.recv().expect("release test reader");
done_flag.store(true, Ordering::Release);
let _ = done_tx.send(());
});
let reader_slot = Mutex::new(Some(SessionPipeReader {
handle,
done_rx,
done,
}));
let error = ManagedProcessSessionState::join_one_reader(&reader_slot, "test")
.expect_err("reader join should time out before release");
assert!(
error.contains("timed out"),
"timeout error should mention shutdown timeout, got: {error}"
);
assert!(
reader_slot
.lock()
.expect("lock reader slot after timeout")
.is_some(),
"reader handle should stay available after timeout"
);
release_tx.send(()).expect("release test reader thread");
ManagedProcessSessionState::join_one_reader(&reader_slot, "test")
.expect("reader join should succeed after release");
assert!(
reader_slot
.lock()
.expect("lock reader slot after join")
.is_none(),
"reader handle should be removed after successful join"
);
}
#[test]
fn closing_process_session_after_child_exit_still_cleans_descendants() {
let lua = Lua::new();
let session = ManagedProcessSession::open(make_descendant_cleanup_request())
.expect("open close descendant cleanup session");
let descendant_pid = wait_for_descendant_pid(&session, Duration::from_secs(15));
assert!(
process_exists(descendant_pid),
"descendant process should be running before close cleanup"
);
let start = Instant::now();
let status = session
.close(&lua, MultiValue::new())
.expect("close descendant cleanup session");
assert!(
start.elapsed() < Duration::from_secs(5),
"process.session.close should not block after descendant cleanup"
);
let exited: bool = status.get("exited").expect("read close exited flag");
assert!(exited, "close should report one exited process status");
assert_process_exits(descendant_pid, Duration::from_secs(5));
}
#[test]
fn read_waits_for_descendant_output_after_root_exit() {
let lua = Lua::new();
let session = ManagedProcessSession::open(make_immediate_exit_request())
.expect("open immediate exit session");
let deadline = Instant::now() + Duration::from_secs(5);
while Instant::now() < deadline {
if session
.state
.peek_status_snapshot()
.expect("peek immediate exit status")
.exited
{
break;
}
thread::sleep(Duration::from_millis(10));
}
assert!(
session
.state
.peek_status_snapshot()
.expect("recheck immediate exit status")
.exited,
"immediate exit process should finish before read regression check"
);
session
.state
.join_reader_threads()
.expect("join real readers before installing test readers");
let install_test_reader =
|| -> (SessionPipeReader, mpsc::Sender<()>, Arc<AtomicBool>) {
let (release_tx, release_rx) = mpsc::channel();
let (done_tx, done_rx) = mpsc::channel();
let done = Arc::new(AtomicBool::new(false));
let done_flag = done.clone();
let handle = thread::spawn(move || {
release_rx.recv().expect("release synthetic session reader");
done_flag.store(true, Ordering::Release);
let _ = done_tx.send(());
});
(
SessionPipeReader {
handle,
done_rx,
done: done.clone(),
},
release_tx,
done,
)
};
let (stdout_reader, stdout_release_tx, _) = install_test_reader();
let (stderr_reader, stderr_release_tx, _) = install_test_reader();
*session
.state
.stdout_reader
.lock()
.expect("lock stdout reader slot for synthetic install") = Some(stdout_reader);
*session
.state
.stderr_reader
.lock()
.expect("lock stderr reader slot for synthetic install") = Some(stderr_reader);
let stdout_buffer = session.state.stdout_buffer.clone();
let release_producer = thread::spawn(move || {
thread::sleep(Duration::from_millis(250));
let mut buffer = stdout_buffer
.lock()
.expect("lock stdout buffer for synthetic descendant output");
append_bounded(
&mut buffer,
b"child-ready\n",
DEFAULT_SESSION_BUFFER_LIMIT_BYTES,
);
drop(buffer);
stdout_release_tx
.send(())
.expect("release synthetic stdout reader");
stderr_release_tx
.send(())
.expect("release synthetic stderr reader");
});
let options = lua.create_table().expect("create read options");
options
.set("timeout_ms", 3_000)
.expect("set read timeout");
options
.set("until_text", "child-ready")
.expect("set read marker");
let mut args = MultiValue::new();
args.push_back(LuaValue::Table(options));
let result = session.read(&lua, args).expect("read descendant output");
let stdout: String = result.get("stdout").expect("read stdout text");
let timed_out: bool = result.get("timed_out").expect("read timed_out flag");
assert!(
!timed_out,
"read should finish from descendant output instead of timing out"
);
assert!(
stdout.contains("child-ready"),
"read should capture descendant output after root exit, got: {stdout:?}"
);
release_producer
.join()
.expect("join synthetic descendant output producer");
session
.state
.join_reader_threads()
.expect("join synthetic session readers");
}
#[cfg(windows)]
#[test]
fn windows_snapshot_open_rejects_future_identity() {
let handle = try_open_windows_process_for_snapshot(std::process::id(), 0)
.expect("open current process for snapshot identity test");
assert!(
handle.is_none(),
"process created after cutoff should be rejected to avoid PID reuse confusion"
);
}
#[cfg(windows)]
#[test]
fn windows_snapshot_open_accepts_existing_identity_before_cutoff() {
let cutoff = current_windows_time_ticks().expect("capture current windows cutoff");
let handle = try_open_windows_process_for_snapshot(std::process::id(), cutoff)
.expect("open current process before cutoff");
assert!(
handle.is_some(),
"existing process should be accepted when it predates the snapshot cutoff"
);
}
}