use std::collections::HashMap;
use std::io::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use portable_pty::{CommandBuilder, MasterPty, PtySize, native_pty_system, Child, ChildKiller};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use crate::{Result, RuntimeError};
/// A child process running attached to a pseudo-terminal (PTY).
///
/// PTY output is pumped by a background blocking task into an unbounded channel and is
/// consumed with `try_read_output`; input is sent with `write`. The `Drop` impl uses
/// `killer` to terminate a child that is still running.
///
/// NOTE: fields are dropped in declaration order after `Drop::drop` runs, so reordering
/// them changes teardown order.
pub struct PtyHandle {
// Master side of the PTY; kept alive for `resize`.
master: Box<dyn MasterPty + Send>,
// Writer obtained via `take_writer`; bytes written here are delivered to the child as input.
writer: Box<dyn Write + Send>,
// Blocking reader task; never awaited (dropping a tokio `JoinHandle` detaches the task).
_reader_task: JoinHandle<()>,
// Output chunks forwarded by the reader task, in read order.
output_rx: mpsc::UnboundedReceiver<Vec<u8>>,
// Spawned child; polled with `try_wait` to detect exit.
child: Box<dyn Child + Send + Sync>,
// Shared liveness flag: cleared by the reader task when it stops, or by `is_alive`
// when the child is observed to have exited.
alive: Arc<AtomicBool>,
// Killer cloned from `child`, used during drop to terminate the process.
killer: Box<dyn ChildKiller + Send + Sync>,
}
impl PtyHandle {
pub fn spawn(
command: &str,
working_dir: Option<&str>,
env: HashMap<String, String>,
rows: u16,
cols: u16,
) -> Result<Self> {
let pty_system = native_pty_system();
let pair = pty_system
.openpty(PtySize {
rows,
cols,
pixel_width: 0,
pixel_height: 0,
})
.map_err(|e| RuntimeError::Tool(format!("Failed to open PTY: {e}")))?;
let parts: Vec<&str> = command.split_whitespace().collect();
let program = parts
.first()
.ok_or_else(|| RuntimeError::Tool("Empty command string".to_string()))?;
let mut cmd = CommandBuilder::new(program);
for arg in parts.iter().skip(1) {
cmd.arg(arg);
}
if let Some(dir) = working_dir {
cmd.cwd(dir);
}
cmd.env("TERM", "xterm-256color");
for (k, v) in &env {
cmd.env(k, v);
}
let child = pair
.slave
.spawn_command(cmd)
.map_err(|e| RuntimeError::Tool(format!("Failed to spawn command: {e}")))?;
drop(pair.slave);
let writer = pair
.master
.take_writer()
.map_err(|e| RuntimeError::Tool(format!("Failed to take PTY writer: {e}")))?;
let mut reader = pair
.master
.try_clone_reader()
.map_err(|e| RuntimeError::Tool(format!("Failed to clone PTY reader: {e}")))?;
let (output_tx, output_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let alive = Arc::new(AtomicBool::new(true));
let reader_alive = alive.clone();
let reader_task = tokio::task::spawn_blocking(move || {
let mut buf = [0u8; 4096];
loop {
match reader.read(&mut buf) {
Ok(0) => {
break;
}
Ok(n) => {
if output_tx.send(buf[..n].to_vec()).is_err() {
break;
}
}
Err(_) => {
break;
}
}
}
reader_alive.store(false, Ordering::SeqCst);
});
let killer = child.clone_killer();
Ok(PtyHandle {
master: pair.master,
writer,
_reader_task: reader_task,
output_rx,
child,
alive,
killer,
})
}
pub fn write(&mut self, input: &[u8]) -> Result<()> {
self.writer
.write_all(input)
.map_err(|e| RuntimeError::Tool(format!("PTY write failed: {e}")))?;
self.writer
.flush()
.map_err(|e| RuntimeError::Tool(format!("PTY flush failed: {e}")))?;
Ok(())
}
pub async fn try_read_output(&mut self, timeout: Duration) -> Vec<u8> {
let mut collected = Vec::new();
while let Ok(chunk) = self.output_rx.try_recv() {
collected.extend_from_slice(&chunk);
}
if collected.is_empty() {
match tokio::time::timeout(timeout, self.output_rx.recv()).await {
Ok(Some(chunk)) => {
collected.extend_from_slice(&chunk);
}
Ok(None) | Err(_) => {
return collected;
}
}
while let Ok(chunk) = self.output_rx.try_recv() {
collected.extend_from_slice(&chunk);
}
}
collected
}
pub fn resize(&self, rows: u16, cols: u16) -> Result<()> {
self.master
.resize(PtySize {
rows,
cols,
pixel_width: 0,
pixel_height: 0,
})
.map_err(|e| RuntimeError::Tool(format!("PTY resize failed: {e}")))
}
pub fn is_alive(&mut self) -> bool {
if !self.alive.load(Ordering::SeqCst) {
return false;
}
match self.child.try_wait() {
Ok(Some(_status)) => {
self.alive.store(false, Ordering::SeqCst);
false
}
Ok(None) => true,
Err(_) => {
self.alive.store(false, Ordering::SeqCst);
false
}
}
}
}
impl Drop for PtyHandle {
    /// Terminates the child if it is still running and collects its exit status if possible.
    ///
    /// Liveness is decided with `Child::try_wait` rather than the shared `alive` flag. The
    /// reader task also clears that flag when it stops, which by itself does not prove the
    /// child has exited.
    ///
    /// `try_wait` also collects the status of a child that has already exited, so that the
    /// child does not linger as a zombie. NOTE(review): this assumes portable_pty's `Child`
    /// delegates to `std::process::Child::try_wait` on Unix — confirm.
    ///
    /// Everything here is best-effort and non-blocking. Errors are ignored, and a killed
    /// child that has not exited by the final `try_wait` is not waited for.
    fn drop(&mut self) {
        match self.child.try_wait() {
            // Already exited; its status has just been collected.
            Ok(Some(_)) => {}
            Ok(None) => {
                let _ = self.killer.kill();
                let _ = self.child.try_wait();
            }
            // Polling failed, so we cannot tell. Fall back to the reader-maintained flag,
            // and avoid signalling a PID that may already have been reaped elsewhere.
            Err(_) => {
                if self.alive.load(Ordering::SeqCst) {
                    let _ = self.killer.kill();
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[tokio::test]
    async fn test_spawn_echo_hello() {
        let mut handle = PtyHandle::spawn("echo hello", None, HashMap::new(), 24, 80)
            .expect("failed to spawn echo");
        let output = handle.try_read_output(Duration::from_secs(3)).await;
        let text = String::from_utf8_lossy(&output);
        assert!(
            text.contains("hello"),
            "expected 'hello' in output, got: {text:?}"
        );
    }

    #[tokio::test]
    async fn test_cat_echo_back() {
        let mut handle = PtyHandle::spawn("cat", None, HashMap::new(), 24, 80)
            .expect("failed to spawn cat");
        handle.write(b"test\n").expect("write failed");
        let output = handle.try_read_output(Duration::from_secs(3)).await;
        let text = String::from_utf8_lossy(&output);
        assert!(
            text.contains("test"),
            "expected 'test' in output, got: {text:?}"
        );
    }

    #[tokio::test]
    async fn test_exit_detection() {
        // `spawn` splits on whitespace with no shell quoting, so the command must keep its
        // meaning after that split. (`bash -c exit 42` would run `exit` with `$0 = 42` and
        // exit with status 0.) `is_alive` only reports liveness, not the exit status.
        let mut handle = PtyHandle::spawn("bash -c exit", None, HashMap::new(), 24, 80)
            .expect("failed to spawn bash exit");
        let _ = handle.try_read_output(Duration::from_secs(3)).await;

        // Poll against a deadline rather than a single fixed sleep, so the test does not
        // become flaky on slow machines.
        let deadline = std::time::Instant::now() + Duration::from_secs(3);
        while handle.is_alive() && std::time::Instant::now() < deadline {
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        assert!(!handle.is_alive(), "expected process to have exited");
    }
}