use std::io::{Read, Write};
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use portable_pty::{Child, CommandBuilder, MasterPty, NativePtySystem, PtySize, PtySystem};
use tracing::warn;
const REAP_DEADLINE: Duration = Duration::from_secs(2);
const REAP_POLL: Duration = Duration::from_millis(10);
/// A child process running on a pseudo-terminal, together with the shared
/// I/O plumbing (master, writer, reader thread) needed to drive it.
///
/// NOTE: field order is also drop order; keep it stable unless the teardown
/// sequence has been re-checked.
pub struct PtyHandle {
    /// The pty master, behind a mutex so `resize` can use it through `&self`.
    master: Arc<Mutex<Box<dyn MasterPty + Send>>>,
    /// Writer obtained from the master; locked per `write` call.
    writer: Arc<Mutex<Box<dyn Write + Send>>>,
    /// The spawned child. Shared with the reader thread; whichever side
    /// `take()`s it first (reader after end-of-stream, or `Drop`) owns reaping.
    child: Arc<Mutex<Option<Box<dyn Child + Send + Sync>>>>,
    /// Join handle of the reader thread; `Option` so `Drop` can take it.
    reader_join: Option<JoinHandle<()>>,
    /// Running total of bytes read from the master by the reader thread.
    bytes_consumed: Arc<std::sync::atomic::AtomicU64>,
}
impl PtyHandle {
pub fn spawn(
shell: &str,
args: &[String],
cwd: Option<&str>,
env: &[(String, String)],
size: PtySize,
mut on_bytes: Box<dyn FnMut(&[u8]) + Send>,
on_exit: Box<dyn FnOnce(Option<i32>) + Send>,
) -> anyhow::Result<Self> {
let pty_system = NativePtySystem::default();
let pair = pty_system.openpty(size)?;
let mut cmd = CommandBuilder::new(shell);
for a in args {
cmd.arg(a);
}
if let Some(d) = cwd {
cmd.cwd(d);
}
for (k, v) in env {
cmd.env(k, v);
}
match cwd {
Some(d) => cmd.env("PWD", d),
None => cmd.env_remove("PWD"),
}
let child = pair.slave.spawn_command(cmd)?;
let child = Arc::new(Mutex::new(Some(child)));
let child_for_reader = Arc::clone(&child);
drop(pair.slave);
let mut reader = pair.master.try_clone_reader()?;
let writer = pair.master.take_writer()?;
let master = Arc::new(Mutex::new(pair.master));
let writer = Arc::new(Mutex::new(writer));
let bytes_consumed = Arc::new(std::sync::atomic::AtomicU64::new(0));
let bytes_consumed_for_thread = Arc::clone(&bytes_consumed);
let reader_join = std::thread::Builder::new()
.name("tear-pty-reader".into())
.spawn(move || {
let mut buf = vec![0u8; 64 * 1024];
loop {
match reader.read(&mut buf) {
Ok(0) => break,
Ok(n) => {
bytes_consumed_for_thread
.fetch_add(n as u64, std::sync::atomic::Ordering::Relaxed);
on_bytes(&buf[..n]);
}
Err(e) => {
warn!(error = %e, "tear pty reader error");
break;
}
}
}
let taken = child_for_reader.lock().take();
let code = taken
.and_then(|mut c| c.wait().ok())
.map(|status| status.exit_code() as i32);
on_exit(code);
})?;
Ok(Self {
master,
writer,
child,
reader_join: Some(reader_join),
bytes_consumed,
})
}
pub fn write(&self, bytes: &[u8]) -> std::io::Result<()> {
let mut w = self.writer.lock();
w.write_all(bytes)
}
pub fn resize(&self, size: PtySize) -> anyhow::Result<()> {
let m = self.master.lock();
m.resize(size)?;
Ok(())
}
pub fn bytes_consumed(&self) -> u64 {
self.bytes_consumed
.load(std::sync::atomic::Ordering::Relaxed)
}
}
impl Drop for PtyHandle {
    /// Kills and reaps the child (bounded by `REAP_DEADLINE`), then detaches
    /// the reader thread.
    fn drop(&mut self) {
        // Take the child in its own statement so the mutex guard is dropped
        // before the potentially multi-second reap. Written as
        // `if let Some(c) = self.child.lock().take()`, the guard would live for
        // the whole `if let` body. That would stall the reader thread, which
        // locks the same slot once its output stream ends.
        let child = self.child.lock().take();
        if let Some(child) = child {
            reap_with_deadline(child);
        }
        // Dropping the JoinHandle detaches the reader thread rather than joining
        // it, so this drop never waits on a pending `read`. The thread finishes
        // on its own and invokes `on_exit`.
        drop(self.reader_join.take());
    }
}
fn reap_with_deadline(mut child: Box<dyn Child + Send + Sync>) {
if let Err(e) = child.kill() {
warn!(error = %e, "tear pty child kill failed (already exited?)");
}
let deadline = Instant::now() + REAP_DEADLINE;
loop {
match child.try_wait() {
Ok(Some(_)) => return,
Err(_) => return,
Ok(None) => {
if Instant::now() >= deadline {
break;
}
std::thread::sleep(REAP_POLL);
}
}
}
warn!("tear pty child survived kill past reap deadline — detaching reaper thread");
let _ = std::thread::Builder::new()
.name("tear-pty-reaper".into())
.spawn(move || {
let _ = child.wait();
});
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    /// Prefix the probe scripts print right before the value under test.
    const MARKER: &str = "PWDCHK[";

    /// A conventional 80x24 terminal with no pixel geometry.
    fn term_80x24() -> PtySize {
        PtySize {
            rows: 24,
            cols: 80,
            pixel_width: 0,
            pixel_height: 0,
        }
    }

    /// Runs `/bin/sh -c <script>` on a pty and returns a receiver fed with
    /// every output chunk, plus the handle that keeps the child alive.
    fn spawn_probe(
        script: &str,
        cwd: Option<&str>,
        env: &[(String, String)],
    ) -> (mpsc::Receiver<Vec<u8>>, PtyHandle) {
        let (tx, rx) = mpsc::channel::<Vec<u8>>();
        let handle = PtyHandle::spawn(
            "/bin/sh",
            &["-c".into(), script.into()],
            cwd,
            env,
            term_80x24(),
            Box::new(move |chunk| {
                let _ = tx.send(chunk.to_vec());
            }),
            Box::new(|_| {}),
        )
        .expect("spawn /bin/sh");
        (rx, handle)
    }

    /// Accumulates output until it contains `MARKER` or two seconds pass,
    /// then decodes it lossily.
    fn collect_until_marker(rx: &mpsc::Receiver<Vec<u8>>) -> String {
        let give_up_at = Instant::now() + Duration::from_secs(2);
        let mut collected = Vec::new();
        while Instant::now() < give_up_at {
            if let Ok(chunk) = rx.recv_timeout(Duration::from_millis(100)) {
                collected.extend_from_slice(&chunk);
                if matches!(std::str::from_utf8(&collected), Ok(s) if s.contains(MARKER)) {
                    break;
                }
            }
        }
        String::from_utf8_lossy(&collected).into_owned()
    }

    #[test]
    fn spawn_with_cwd_sets_env_pwd_consistent() {
        let dir = std::env::temp_dir();
        let dir_str = dir.to_string_lossy().into_owned();
        let (rx, _handle) = spawn_probe(
            "printf 'PWDCHK[%s]\\n' \"$PWD\"",
            Some(&dir_str),
            &[("PWD".into(), "/stale/parent".into())],
        );
        let text = collect_until_marker(&rx);
        let expected = format!("PWDCHK[{}", dir_str.trim_end_matches('/'));
        assert!(
            text.contains(&expected),
            "child $PWD must match the spawn cwd, not the stale parent PWD: {text:?}"
        );
        assert!(
            !text.contains("/stale/parent"),
            "the stale parent PWD leaked to the child: {text:?}"
        );
    }

    #[test]
    fn spawn_without_cwd_never_leaks_parent_pwd() {
        let (rx, _handle) = spawn_probe(
            "printf 'PWDCHK[%s]\\n' \"${PWD:-UNSET}\"",
            None,
            &[("PWD".into(), "/stale/parent".into())],
        );
        let text = collect_until_marker(&rx);
        assert!(text.contains(MARKER), "no PWDCHK output: {text:?}");
        assert!(
            !text.contains("/stale/parent"),
            "no-cwd spawn must strip the inherited PWD; the stale parent leaked: {text:?}"
        );
    }

    #[test]
    fn drop_reaps_sighup_immune_child_within_deadline() {
        let budget = Duration::from_secs(5);
        let dir = std::env::temp_dir();
        let handle = PtyHandle::spawn(
            "/usr/bin/nohup",
            &["cat".into()],
            dir.to_str(),
            &[("PATH".into(), "/usr/bin:/bin".into())],
            term_80x24(),
            Box::new(|_| {}),
            Box::new(|_| {}),
        )
        .expect("spawn nohup cat");
        // Presumably lets nohup exec `cat` with SIGHUP ignored before we drop.
        std::thread::sleep(Duration::from_millis(200));

        let started = Instant::now();
        drop(handle);
        let elapsed = started.elapsed();

        assert!(
            elapsed < budget,
            "PtyHandle drop blocked {elapsed:?} on a SIGHUP-immune child — reap is unbounded again"
        );
    }
}