use std::collections::HashMap;
use std::fmt;
use std::io::{self, Read, Write};
use std::path::PathBuf;
use std::sync::mpsc;
use std::thread;
use std::time::{Duration, Instant};
use portable_pty::{CommandBuilder, ExitStatus, PtySize};
/// Settings used by `PtyProcess::spawn` to launch a shell on a new PTY.
///
/// Start from [`ShellConfig::default`] or [`ShellConfig::with_shell`] and refine
/// the result with the chained builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShellConfig {
    /// Program to run. When `None`, a non-empty `$SHELL` is used, falling back
    /// to `/bin/sh`.
    pub shell: Option<PathBuf>,
    /// Arguments passed to the shell.
    pub args: Vec<String>,
    /// Environment variables set for the child.
    pub env: HashMap<String, String>,
    /// Working directory for the child; `None` leaves the choice to
    /// `portable_pty`'s `CommandBuilder`.
    pub cwd: Option<PathBuf>,
    /// Terminal width in columns.
    pub cols: u16,
    /// Terminal height in rows.
    pub rows: u16,
    /// Value exported to the child as `TERM`.
    pub term: String,
    /// Whether lifecycle and I/O events are logged to stderr.
    pub log_events: bool,
}

impl Default for ShellConfig {
    /// An 80x24 terminal with `TERM=xterm-256color`, no shell override, no
    /// extra arguments or environment, and logging disabled.
    fn default() -> Self {
        Self {
            shell: None,
            args: Vec::new(),
            env: HashMap::new(),
            cwd: None,
            cols: 80,
            rows: 24,
            term: "xterm-256color".to_string(),
            log_events: false,
        }
    }
}

impl ShellConfig {
    /// Default configuration that runs `shell` instead of the resolved default.
    #[must_use]
    pub fn with_shell(shell: impl Into<PathBuf>) -> Self {
        Self {
            shell: Some(shell.into()),
            ..Default::default()
        }
    }

    /// Sets the initial terminal size.
    #[must_use]
    pub fn size(mut self, cols: u16, rows: u16) -> Self {
        self.cols = cols;
        self.rows = rows;
        self
    }

    /// Appends one argument for the shell.
    #[must_use]
    pub fn arg(mut self, arg: impl Into<String>) -> Self {
        self.args.push(arg.into());
        self
    }

    /// Sets (or replaces) one environment variable for the child.
    #[must_use]
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.env.insert(key.into(), value.into());
        self
    }

    /// Copies the current process environment into `env`, keeping any value
    /// that was already set explicitly.
    ///
    /// Variables whose name or value is not valid Unicode are skipped, because
    /// `env` stores `String`s.
    #[must_use]
    pub fn inherit_env(mut self) -> Self {
        // `vars_os` instead of `vars`: `vars` panics on a non-Unicode entry.
        for (key, value) in std::env::vars_os() {
            if let (Ok(key), Ok(value)) = (key.into_string(), value.into_string()) {
                self.env.entry(key).or_insert(value);
            }
        }
        self
    }

    /// Sets the working directory for the child.
    #[must_use]
    pub fn cwd(mut self, path: impl Into<PathBuf>) -> Self {
        self.cwd = Some(path.into());
        self
    }

    /// Sets the value exported as `TERM`.
    #[must_use]
    pub fn term(mut self, term: impl Into<String>) -> Self {
        self.term = term.into();
        self
    }

    /// Enables or disables event logging.
    #[must_use]
    pub fn logging(mut self, enabled: bool) -> Self {
        self.log_events = enabled;
        self
    }

    /// Picks the program to launch: the explicit `shell`, else a non-empty
    /// `$SHELL`, else `/bin/sh`.
    fn resolve_shell(&self) -> PathBuf {
        if let Some(shell) = &self.shell {
            return shell.clone();
        }
        // `var_os` keeps non-UTF-8 paths intact. An empty `$SHELL` cannot name
        // a program, so it is treated as unset.
        std::env::var_os("SHELL")
            .filter(|value| !value.is_empty())
            .map_or_else(|| PathBuf::from("/bin/sh"), PathBuf::from)
    }
}
/// One message from the PTY reader thread to the owning process handle.
///
/// The reader thread sends any number of `Data` messages, followed by at most
/// one `Eof` or `Err`, after which it stops sending.
#[derive(Debug)]
enum ReaderMsg {
    /// A chunk of raw bytes read from the PTY master.
    Data(Vec<u8>),
    /// End of stream was reached.
    Eof,
    /// Reading from the PTY failed.
    Err(io::Error),
}
/// Last observed lifecycle state of a PTY child process.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProcessState {
    /// The child has not been observed to exit.
    Running,
    /// The child exited with the given exit code.
    Exited(i32),
    /// The child was terminated by a signal (payload presumably the signal
    /// number — confirm against the producer).
    Signaled(i32),
    /// The state could not be determined.
    Unknown,
}

impl ProcessState {
    /// `true` only for [`ProcessState::Running`]; every other state counts as
    /// not alive, including [`ProcessState::Unknown`].
    #[must_use]
    pub const fn is_alive(self) -> bool {
        match self {
            ProcessState::Running => true,
            ProcessState::Exited(_) | ProcessState::Signaled(_) | ProcessState::Unknown => false,
        }
    }

    /// The exit code carried by [`ProcessState::Exited`], or `None` for any
    /// other state.
    #[must_use]
    pub const fn exit_code(self) -> Option<i32> {
        if let ProcessState::Exited(code) = self {
            Some(code)
        } else {
            None
        }
    }
}
pub struct PtyProcess {
child: Box<dyn portable_pty::Child + Send + Sync>,
writer: Box<dyn Write + Send>,
rx: mpsc::Receiver<ReaderMsg>,
reader_thread: Option<thread::JoinHandle<()>>,
captured: Vec<u8>,
eof: bool,
state: ProcessState,
config: ShellConfig,
}
impl fmt::Debug for PtyProcess {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PtyProcess")
.field("pid", &self.child.process_id())
.field("state", &self.state)
.field("captured_len", &self.captured.len())
.field("eof", &self.eof)
.finish()
}
}
impl PtyProcess {
pub fn spawn(config: ShellConfig) -> io::Result<Self> {
let shell_path = config.resolve_shell();
if config.log_events {
log_event(
"PTY_PROCESS_SPAWN",
format!("shell={}", shell_path.display()),
);
}
let mut cmd = CommandBuilder::new(&shell_path);
for arg in &config.args {
cmd.arg(arg);
}
cmd.env("TERM", &config.term);
for (key, value) in &config.env {
cmd.env(key, value);
}
if let Some(ref cwd) = config.cwd {
cmd.cwd(cwd);
}
let pty_system = portable_pty::native_pty_system();
let pair = pty_system
.openpty(PtySize {
rows: config.rows,
cols: config.cols,
pixel_width: 0,
pixel_height: 0,
})
.map_err(|e| io::Error::other(e.to_string()))?;
let child = pair
.slave
.spawn_command(cmd)
.map_err(|e| io::Error::other(e.to_string()))?;
let mut reader = pair
.master
.try_clone_reader()
.map_err(|e| io::Error::other(e.to_string()))?;
let writer = pair
.master
.take_writer()
.map_err(|e| io::Error::other(e.to_string()))?;
let (tx, rx) = mpsc::channel::<ReaderMsg>();
let reader_thread = thread::spawn(move || {
let mut buf = [0u8; 8192];
loop {
match reader.read(&mut buf) {
Ok(0) => {
let _ = tx.send(ReaderMsg::Eof);
break;
}
Ok(n) => {
let _ = tx.send(ReaderMsg::Data(buf[..n].to_vec()));
}
Err(err) => {
let _ = tx.send(ReaderMsg::Err(err));
break;
}
}
}
});
if config.log_events {
log_event(
"PTY_PROCESS_STARTED",
format!("pid={:?}", child.process_id()),
);
}
Ok(Self {
child,
writer,
rx,
reader_thread: Some(reader_thread),
captured: Vec::new(),
eof: false,
state: ProcessState::Running,
config,
})
}
#[must_use]
pub fn is_alive(&mut self) -> bool {
self.poll_state();
self.state.is_alive()
}
#[must_use]
pub fn state(&mut self) -> ProcessState {
self.poll_state();
self.state
}
#[must_use]
pub fn pid(&self) -> Option<u32> {
self.child.process_id()
}
pub fn kill(&mut self) -> io::Result<()> {
if !self.state.is_alive() {
return Ok(());
}
if self.config.log_events {
log_event(
"PTY_PROCESS_KILL",
format!("pid={:?}", self.child.process_id()),
);
}
self.child.kill()?;
self.state = ProcessState::Unknown;
match self.wait_timeout(Duration::from_millis(100)) {
Ok(status) => {
self.update_state_from_exit(&status);
}
Err(_) => {
self.state = ProcessState::Unknown;
}
}
Ok(())
}
pub fn wait(&mut self) -> io::Result<ExitStatus> {
let status = self.child.wait()?;
self.update_state_from_exit(&status);
Ok(status)
}
pub fn wait_timeout(&mut self, timeout: Duration) -> io::Result<ExitStatus> {
let deadline = Instant::now() + timeout;
loop {
match self.child.try_wait()? {
Some(status) => {
self.update_state_from_exit(&status);
return Ok(status);
}
None => {
if Instant::now() >= deadline {
return Err(io::Error::new(
io::ErrorKind::TimedOut,
"wait_timeout: process did not exit in time",
));
}
thread::sleep(Duration::from_millis(10));
}
}
}
}
pub fn write_all(&mut self, data: &[u8]) -> io::Result<()> {
self.writer.write_all(data)?;
self.writer.flush()?;
if self.config.log_events {
log_event("PTY_PROCESS_INPUT", format!("bytes={}", data.len()));
}
Ok(())
}
pub fn read_available(&mut self) -> io::Result<Vec<u8>> {
self.drain_channel(Duration::ZERO)?;
Ok(self.captured.clone())
}
pub fn read_until(&mut self, pattern: &[u8], timeout: Duration) -> io::Result<Vec<u8>> {
if pattern.is_empty() {
return Ok(self.captured.clone());
}
let deadline = Instant::now() + timeout;
loop {
if find_subsequence(&self.captured, pattern).is_some() {
return Ok(self.captured.clone());
}
if self.eof || Instant::now() >= deadline {
break;
}
let remaining = deadline.saturating_duration_since(Instant::now());
self.drain_channel(remaining)?;
}
Err(io::Error::new(
io::ErrorKind::TimedOut,
format!(
"read_until: pattern not found (captured {} bytes)",
self.captured.len()
),
))
}
pub fn drain(&mut self, timeout: Duration) -> io::Result<usize> {
if self.eof {
return Ok(0);
}
let start_len = self.captured.len();
let deadline = Instant::now() + timeout;
while !self.eof && Instant::now() < deadline {
let remaining = deadline.saturating_duration_since(Instant::now());
match self.drain_channel(remaining) {
Ok(0) if self.eof => break,
Ok(_) => continue,
Err(e) if e.kind() == io::ErrorKind::TimedOut => break,
Err(e) => return Err(e),
}
}
Ok(self.captured.len() - start_len)
}
#[must_use]
pub fn output(&self) -> &[u8] {
&self.captured
}
pub fn clear_output(&mut self) {
self.captured.clear();
}
pub fn resize(&mut self, cols: u16, rows: u16) -> io::Result<()> {
if self.config.log_events {
log_event("PTY_PROCESS_RESIZE", format!("cols={} rows={}", cols, rows));
}
Ok(())
}
fn poll_state(&mut self) {
if !self.state.is_alive() {
return;
}
match self.child.try_wait() {
Ok(Some(status)) => {
self.update_state_from_exit(&status);
}
Ok(None) => {
}
Err(_) => {
self.state = ProcessState::Unknown;
}
}
}
fn update_state_from_exit(&mut self, status: &ExitStatus) {
if status.success() {
self.state = ProcessState::Exited(0);
} else {
let code = 1; self.state = ProcessState::Exited(code);
}
}
fn drain_channel(&mut self, timeout: Duration) -> io::Result<usize> {
if self.eof {
return Ok(0);
}
let mut total = 0usize;
let first = if timeout.is_zero() {
match self.rx.try_recv() {
Ok(msg) => Some(msg),
Err(mpsc::TryRecvError::Empty) => return Ok(0),
Err(mpsc::TryRecvError::Disconnected) => {
self.eof = true;
return Ok(0);
}
}
} else {
match self.rx.recv_timeout(timeout) {
Ok(msg) => Some(msg),
Err(mpsc::RecvTimeoutError::Timeout) => return Ok(0),
Err(mpsc::RecvTimeoutError::Disconnected) => {
self.eof = true;
return Ok(0);
}
}
};
let mut msg = match first {
Some(m) => m,
None => return Ok(0),
};
loop {
match msg {
ReaderMsg::Data(bytes) => {
total = total.saturating_add(bytes.len());
self.captured.extend_from_slice(&bytes);
}
ReaderMsg::Eof => {
self.eof = true;
break;
}
ReaderMsg::Err(err) => return Err(err),
}
match self.rx.try_recv() {
Ok(next) => msg = next,
Err(mpsc::TryRecvError::Empty) => break,
Err(mpsc::TryRecvError::Disconnected) => {
self.eof = true;
break;
}
}
}
if total > 0 && self.config.log_events {
log_event("PTY_PROCESS_OUTPUT", format!("bytes={}", total));
}
Ok(total)
}
}
impl Drop for PtyProcess {
fn drop(&mut self) {
let _ = self.writer.flush();
let _ = self.child.kill();
if let Some(handle) = self.reader_thread.take() {
let _ = handle.join();
}
if self.config.log_events {
log_event(
"PTY_PROCESS_DROP",
format!("pid={:?}", self.child.process_id()),
);
}
}
}
/// Index of the first occurrence of `needle` within `haystack`.
///
/// An empty `needle` matches at index 0. A `needle` longer than `haystack`
/// never matches.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    let width = needle.len();
    if width == 0 {
        return Some(0);
    }
    // No candidate start exists when the needle does not fit.
    let last_start = haystack.len().checked_sub(width)?;
    (0..=last_start).find(|&start| &haystack[start..start + width] == needle)
}
/// Writes one `[timestamp] EVENT: detail` line to stderr.
///
/// The timestamp is the current UTC time in RFC 3339 format. If formatting
/// fails, the Unix epoch is printed instead.
fn log_event(event: &str, detail: impl fmt::Display) {
    const FALLBACK_TIMESTAMP: &str = "1970-01-01T00:00:00Z";
    let now = time::OffsetDateTime::now_utc();
    let timestamp = match now.format(&time::format_description::well_known::Rfc3339) {
        Ok(formatted) => formatted,
        Err(_) => FALLBACK_TIMESTAMP.to_string(),
    };
    eprintln!("[{}] {}: {}", timestamp, event, detail);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shell_config_defaults() {
        let config = ShellConfig::default();
        assert!(config.shell.is_none());
        assert!(config.args.is_empty());
        assert!(config.env.is_empty());
        assert!(config.cwd.is_none());
        assert_eq!(config.cols, 80);
        assert_eq!(config.rows, 24);
        assert_eq!(config.term, "xterm-256color");
        assert!(!config.log_events);
    }

    #[test]
    fn shell_config_with_shell() {
        let config = ShellConfig::with_shell("/bin/bash");
        assert_eq!(config.shell, Some(PathBuf::from("/bin/bash")));
    }

    #[test]
    fn shell_config_builder_chain() {
        let config = ShellConfig::default()
            .size(120, 40)
            .arg("-l")
            .env("FOO", "bar")
            .cwd("/tmp")
            .term("dumb")
            .logging(true);
        assert_eq!(config.cols, 120);
        assert_eq!(config.rows, 40);
        assert_eq!(config.args, vec!["-l"]);
        assert_eq!(config.env.get("FOO"), Some(&"bar".to_string()));
        assert_eq!(config.cwd, Some(PathBuf::from("/tmp")));
        assert_eq!(config.term, "dumb");
        assert!(config.log_events);
    }

    #[test]
    fn shell_config_resolve_shell_explicit() {
        let config = ShellConfig::with_shell("/bin/zsh");
        assert_eq!(config.resolve_shell(), PathBuf::from("/bin/zsh"));
    }

    #[test]
    fn shell_config_resolve_shell_env() {
        let config = ShellConfig::default();
        let shell = config.resolve_shell();
        // Compare against the environment instead of guessing at shell names,
        // so the test holds for any login shell (fish, nu, ...).
        match std::env::var("SHELL") {
            Ok(value) if !value.is_empty() => assert_eq!(shell, PathBuf::from(value)),
            Err(std::env::VarError::NotPresent) => {
                assert_eq!(shell, PathBuf::from("/bin/sh"));
            }
            _ => {}
        }
    }

    #[test]
    fn process_state_is_alive() {
        assert!(ProcessState::Running.is_alive());
        assert!(!ProcessState::Exited(0).is_alive());
        assert!(!ProcessState::Signaled(9).is_alive());
        assert!(!ProcessState::Unknown.is_alive());
    }

    #[test]
    fn process_state_exit_code() {
        assert_eq!(ProcessState::Running.exit_code(), None);
        assert_eq!(ProcessState::Exited(0).exit_code(), Some(0));
        assert_eq!(ProcessState::Exited(1).exit_code(), Some(1));
        assert_eq!(ProcessState::Signaled(9).exit_code(), None);
        assert_eq!(ProcessState::Unknown.exit_code(), None);
    }

    #[test]
    fn find_subsequence_empty_needle() {
        assert_eq!(find_subsequence(b"anything", b""), Some(0));
    }

    #[test]
    fn find_subsequence_found() {
        assert_eq!(find_subsequence(b"hello world", b"world"), Some(6));
    }

    #[test]
    fn find_subsequence_not_found() {
        assert_eq!(find_subsequence(b"hello world", b"xyz"), None);
    }

    #[cfg(unix)]
    #[test]
    fn spawn_and_basic_io() {
        let config = ShellConfig::default().logging(false);
        let mut proc = PtyProcess::spawn(config).expect("spawn should succeed");
        assert!(proc.is_alive());
        assert!(proc.pid().is_some());
        proc.write_all(b"echo hello-pty-process\n")
            .expect("write should succeed");
        let output = proc
            .read_until(b"hello-pty-process", Duration::from_secs(5))
            .expect("should find output");
        assert!(
            output
                .windows(b"hello-pty-process".len())
                .any(|w| w == b"hello-pty-process"),
            "expected to find 'hello-pty-process' in output"
        );
        proc.kill().expect("kill should succeed");
        assert!(!proc.is_alive());
    }

    #[cfg(unix)]
    #[test]
    fn spawn_with_env() {
        let config = ShellConfig::default()
            .logging(false)
            .env("TEST_VAR", "test_value_123");
        let mut proc = PtyProcess::spawn(config).expect("spawn should succeed");
        proc.write_all(b"echo $TEST_VAR\n")
            .expect("write should succeed");
        let output = proc
            .read_until(b"test_value_123", Duration::from_secs(5))
            .expect("should find env var in output");
        assert!(
            output
                .windows(b"test_value_123".len())
                .any(|w| w == b"test_value_123"),
            "expected to find env var value in output"
        );
        proc.kill().expect("kill should succeed");
    }

    #[cfg(unix)]
    #[test]
    fn exit_command_terminates() {
        let config = ShellConfig::default().logging(false);
        let mut proc = PtyProcess::spawn(config).expect("spawn should succeed");
        proc.write_all(b"exit 0\n").expect("write should succeed");
        let status = proc
            .wait_timeout(Duration::from_secs(5))
            .expect("wait should succeed");
        assert!(status.success());
        assert!(!proc.is_alive());
    }

    #[cfg(unix)]
    #[test]
    fn kill_is_idempotent() {
        let config = ShellConfig::default().logging(false);
        let mut proc = PtyProcess::spawn(config).expect("spawn should succeed");
        proc.kill().expect("first kill should succeed");
        proc.kill().expect("second kill should succeed");
        proc.kill().expect("third kill should succeed");
        assert!(!proc.is_alive());
    }

    #[cfg(unix)]
    #[test]
    fn drain_captures_all_output() {
        let config = ShellConfig::default().logging(false);
        let mut proc = PtyProcess::spawn(config).expect("spawn should succeed");
        proc.write_all(b"for i in 1 2 3 4 5; do echo line$i; done; exit 0\n")
            .expect("write should succeed");
        let _ = proc.wait_timeout(Duration::from_secs(5));
        let _ = proc.drain(Duration::from_secs(2));
        let output = String::from_utf8_lossy(proc.output());
        for i in 1..=5 {
            assert!(
                output.contains(&format!("line{i}")),
                "missing line{i} in output: {output:?}"
            );
        }
    }

    #[cfg(unix)]
    #[test]
    fn clear_output_works() {
        let config = ShellConfig::default().logging(false);
        let mut proc = PtyProcess::spawn(config).expect("spawn should succeed");
        proc.write_all(b"echo test\n")
            .expect("write should succeed");
        // Wait for output deterministically rather than sleeping a fixed time,
        // which is flaky on slow machines.
        proc.read_until(b"test", Duration::from_secs(5))
            .expect("should see output");
        assert!(!proc.output().is_empty());
        proc.clear_output();
        assert!(proc.output().is_empty());
        proc.kill().expect("kill should succeed");
    }

    #[cfg(unix)]
    #[test]
    fn resize_succeeds() {
        let config = ShellConfig::default().logging(false);
        let mut proc = PtyProcess::spawn(config).expect("spawn should succeed");
        proc.resize(100, 30).expect("resize should succeed");
        proc.kill().expect("kill should succeed");
    }

    #[cfg(unix)]
    #[test]
    fn specific_shell_path() {
        let config = ShellConfig::with_shell("/bin/sh").logging(false);
        let mut proc = PtyProcess::spawn(config).expect("spawn should succeed");
        assert!(proc.is_alive());
        proc.kill().expect("kill should succeed");
    }

    #[cfg(unix)]
    #[test]
    fn invalid_shell_fails() {
        let config = ShellConfig::with_shell("/nonexistent/shell").logging(false);
        let result = PtyProcess::spawn(config);
        assert!(result.is_err());
    }
}