use portable_pty::{native_pty_system, CommandBuilder, MasterPty, PtySize};
use std::io::{self, Read, Write};
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
#[cfg(unix)]
use libc;
/// Result returned by [`run`] once the child process has exited before any
/// timeout elapsed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RunResult {
    /// Exit code reported for the child process.
    pub exit_code: u32,
}
#[derive(Debug, thiserror::Error)]
pub enum PtyError {
#[error("failed to open PTY: {0}")]
Open(anyhow::Error),
#[error("failed to spawn command: {0}")]
Spawn(anyhow::Error),
#[error("failed to clone PTY reader: {0}")]
Reader(anyhow::Error),
#[error("I/O error: {0}")]
Io(#[from] io::Error),
#[error("timeout: child did not exit within {0}s")]
Timeout(u64),
}
/// Releases the PTY master, then waits up to two seconds for the output
/// reader thread to finish.
///
/// The master is dropped *before* waiting. portable_pty's own example drops
/// the master once the child is done and only then drains the reader. On some
/// platforms (presumably Windows/ConPTY — verify) the reader only sees EOF
/// after the master is gone, so waiting first would always burn the full
/// grace period. The reader was obtained via `try_clone_reader`, so it is
/// expected to keep draining already-buffered output on its own.
///
/// If the reader is still running when the grace period expires, the shared
/// `stop` flag is set so the thread exits after its next read returns. The
/// thread is then left detached instead of being joined.
fn join_reader(
    handle: JoinHandle<()>,
    stop: &Arc<Mutex<bool>>,
    master: Box<dyn portable_pty::MasterPty + Send>,
) {
    drop(master);

    let deadline = Instant::now() + Duration::from_secs(2);
    while !handle.is_finished() {
        if Instant::now() >= deadline {
            if let Ok(mut stop_requested) = stop.lock() {
                *stop_requested = true;
            }
            return;
        }
        thread::sleep(Duration::from_millis(50));
    }
    // The reader has already finished; a panic inside it is not actionable here.
    let _ = handle.join();
}
/// Types `/exit` followed by a newline into the PTY, then deletes the signal
/// file that requested it.
///
/// Best effort throughout. If the writer cannot be taken, or the write or
/// flush fails, the error is ignored and the signal file is still removed.
/// The writer goes out of scope before the file is removed.
fn send_exit_command(master: &dyn MasterPty, signal_path: &Path) {
    const EXIT_LINE: &[u8] = b"/exit\n";

    let writer = master.take_writer();
    if let Ok(mut pty_input) = writer {
        // The flush is attempted even when the write itself failed.
        let _ = pty_input.write_all(EXIT_LINE);
        let _ = pty_input.flush();
    }

    // Consume the signal whether or not the command could be delivered.
    let _ = std::fs::remove_file(signal_path);
}
pub fn run(
argv: &[String],
cwd: &Path,
timeout_secs: u64,
exit_signal: Option<&Path>,
) -> Result<RunResult, PtyError> {
let pty_system = native_pty_system();
let pair = pty_system
.openpty(PtySize {
rows: 50,
cols: 200,
pixel_width: 0,
pixel_height: 0,
})
.map_err(PtyError::Open)?;
let mut cmd = CommandBuilder::new(&argv[0]);
if argv.len() > 1 {
cmd.args(&argv[1..]);
}
cmd.cwd(cwd);
let mut child = pair.slave.spawn_command(cmd).map_err(PtyError::Spawn)?;
drop(pair.slave);
let mut killer = child.clone_killer();
let mut reader = pair.master.try_clone_reader().map_err(PtyError::Reader)?;
let stop = Arc::new(Mutex::new(false));
let stop_reader = Arc::clone(&stop);
let reader_thread = thread::spawn(move || {
let stdout = io::stdout();
let mut buf = [0u8; 4096];
loop {
if stop_reader.lock().is_ok_and(|g| *g) {
break;
}
match reader.read(&mut buf) {
Ok(0) => break, Ok(n) => {
let mut out = stdout.lock();
if out.write_all(&buf[..n]).is_err() {
break;
}
let _ = out.flush();
}
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
Err(_) => break, }
}
});
if let Some(sig) = exit_signal {
let _ = std::fs::remove_file(sig);
}
let poll_interval = Duration::from_millis(100);
let deadline = if timeout_secs > 0 {
Some(Instant::now() + Duration::from_secs(timeout_secs))
} else {
None
};
let mut exit_sent = false;
let exit_code = loop {
match child.try_wait() {
Ok(Some(status)) => break status.exit_code(),
Ok(None) => {
if let Some(dl) = deadline {
if Instant::now() >= dl {
#[cfg(unix)]
{
if let Some(pgid) = pair.master.process_group_leader() {
unsafe {
libc::killpg(pgid, libc::SIGKILL);
}
}
}
let _ = killer.kill(); thread::sleep(Duration::from_millis(500));
join_reader(reader_thread, &stop, pair.master);
return Err(PtyError::Timeout(timeout_secs));
}
}
if !exit_sent {
if let Some(sig) = exit_signal {
if sig.exists() {
send_exit_command(pair.master.as_ref(), sig);
exit_sent = true;
}
}
}
thread::sleep(poll_interval);
}
Err(e) => return Err(PtyError::Io(e)),
}
};
join_reader(reader_thread, &stop, pair.master);
Ok(RunResult { exit_code })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Directory used as the child's working directory and for signal files.
    fn tmp() -> PathBuf {
        std::env::temp_dir()
    }

    /// Builds the owned `argv` that `run` expects from string literals.
    fn argv(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| part.to_string()).collect()
    }

    /// Signal-file path unique to this test process.
    ///
    /// The temp dir is shared, so a fixed name could collide with a concurrent
    /// `cargo test` run. It could also collide with a leftover file owned by
    /// another user, which this process may be unable to overwrite or delete.
    fn signal_file(tag: &str) -> PathBuf {
        tmp().join(format!("test-exit-signal-{tag}-{}.tmp", std::process::id()))
    }

    #[cfg(unix)]
    #[test]
    fn echo_hello_exit_zero() {
        let result =
            run(&argv(&["echo", "hello"]), &tmp(), 10, None).expect("run should succeed");
        assert_eq!(result.exit_code, 0, "echo should exit 0");
    }

    #[cfg(unix)]
    #[test]
    fn nonzero_exit_code_propagated() {
        let result = run(&argv(&["false"]), &tmp(), 10, None).expect("run should succeed");
        assert_ne!(result.exit_code, 0, "false should exit non-zero");
    }

    #[cfg(unix)]
    #[test]
    fn timeout_fires() {
        let err = run(&argv(&["sleep", "60"]), &tmp(), 1, None).expect_err("should time out");
        match err {
            PtyError::Timeout(secs) => assert_eq!(secs, 1),
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    #[cfg(unix)]
    #[test]
    fn exit_signal_triggers_child_exit() {
        use std::fs;

        let signal_path = signal_file("trigger");
        let _ = fs::remove_file(&signal_path);

        // Create the signal file shortly after the child has started.
        let writer_thread = {
            let signal_path = signal_path.clone();
            thread::spawn(move || {
                thread::sleep(Duration::from_millis(300));
                fs::write(&signal_path, b"").expect("write signal file");
            })
        };

        let result = run(
            &argv(&["sh", "-c", "read line"]),
            &tmp(),
            10,
            Some(&signal_path),
        )
        .expect("run should succeed");
        writer_thread.join().expect("writer thread panicked");

        assert_eq!(result.exit_code, 0, "child should exit 0 after signal");
        assert!(
            !signal_path.exists(),
            "signal file should be deleted after /exit is sent"
        );
    }

    #[cfg(unix)]
    #[test]
    fn stale_signal_file_deleted_before_poll() {
        use std::fs;

        let signal_path = signal_file("stale");
        fs::write(&signal_path, b"").expect("write stale signal file");
        assert!(
            signal_path.exists(),
            "precondition: stale file should exist"
        );

        let result = run(
            &argv(&["echo", "hello"]),
            &tmp(),
            10,
            Some(&signal_path),
        )
        .expect("run should succeed");

        assert_eq!(result.exit_code, 0);
        assert!(
            !signal_path.exists(),
            "stale signal file should be deleted by run() before poll loop"
        );
    }
}