use std::collections::HashMap;
use std::io::{self, Read, Write};
use std::path::PathBuf;
/// Parameters for launching a command inside a pseudo-terminal.
///
/// Build one with [`PtyConfig::new`] and the chained setters, or start from
/// [`PtyConfig::default`] (24 rows by 80 columns, no command).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PtyConfig {
    /// Terminal height in character rows.
    pub rows: u16,
    /// Terminal width in character columns.
    pub cols: u16,
    /// Program to execute.
    pub cmd: String,
    /// Arguments passed to `cmd` (the program itself is given separately in `cmd`).
    pub args: Vec<String>,
    /// Working directory for the child; `None` means no explicit directory is set.
    pub cwd: Option<PathBuf>,
    /// Additional environment variables set for the child.
    pub env: HashMap<String, String>,
}
impl Default for PtyConfig {
    /// A 24-row by 80-column terminal with an empty command, no arguments,
    /// no working directory, and no extra environment.
    fn default() -> Self {
        const DEFAULT_ROWS: u16 = 24;
        const DEFAULT_COLS: u16 = 80;
        Self {
            cmd: String::default(),
            args: Vec::default(),
            cwd: None,
            env: HashMap::default(),
            rows: DEFAULT_ROWS,
            cols: DEFAULT_COLS,
        }
    }
}
impl PtyConfig {
    /// Creates a configuration that runs `cmd` with the default 24x80 size and
    /// no arguments, working directory, or extra environment.
    #[must_use]
    pub fn new(cmd: impl Into<String>) -> Self {
        Self {
            cmd: cmd.into(),
            ..Default::default()
        }
    }

    /// Sets the argument list, discarding any previously set arguments.
    #[must_use]
    pub fn args(mut self, args: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.args = args.into_iter().map(Into::into).collect();
        self
    }

    /// Sets the child's working directory.
    #[must_use]
    pub fn cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
        self.cwd = Some(cwd.into());
        self
    }

    /// Adds an environment variable for the child, replacing any value
    /// previously set for the same `key`.
    #[must_use]
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.env.insert(key.into(), value.into());
        self
    }

    /// Sets the terminal size to `rows` x `cols` character cells.
    #[must_use]
    pub fn size(mut self, rows: u16, cols: u16) -> Self {
        self.rows = rows;
        self.cols = cols;
        self
    }
}
/// A process running inside a pseudo-terminal, as handed out by [`PtyHandle::spawn`].
pub trait PtyChild: Send {
    /// Polls whether the process has exited.
    ///
    /// Returns `Ok(None)` while it is still running and `Ok(Some(code))` once
    /// it has exited.
    ///
    /// # Errors
    /// Returns an error if the backend cannot query the process status.
    fn try_wait(&mut self) -> io::Result<Option<i32>>;

    /// Requests termination of the process.
    ///
    /// # Errors
    /// Returns an error if the backend fails to deliver the request.
    fn kill(&mut self) -> io::Result<()>;

    /// Resizes this process's terminal to `rows` x `cols` character cells.
    ///
    /// # Errors
    /// Returns an error if the backend rejects the new size.
    fn resize(&mut self, rows: u16, cols: u16) -> io::Result<()>;
}
/// Return type of [`PtyHandle::spawn`]: on success, the terminal's output
/// reader, its input writer, and a handle to the spawned process.
pub type SpawnResult<R, W> = io::Result<(R, W, Box<dyn PtyChild>)>;
/// A backend that can launch commands inside pseudo-terminals.
pub trait PtyHandle: Send + Sync {
    /// Stream of bytes the child writes to its terminal.
    type Reader: Read + Send + 'static;
    /// Stream of bytes delivered to the child as terminal input.
    type Writer: Write + Send + 'static;

    /// Launches `config.cmd` in a new pseudo-terminal sized from `config`.
    ///
    /// Returns the terminal's output reader, its input writer, and a handle to
    /// the spawned process.
    ///
    /// # Errors
    /// Returns an error if the terminal cannot be opened or the command cannot
    /// be started.
    fn spawn(&self, config: &PtyConfig) -> SpawnResult<Self::Reader, Self::Writer>;

    /// Handle-level resize request.
    ///
    /// Backends may treat this as a no-op (the native backend does). To resize
    /// the terminal of a specific spawned process, use [`PtyChild::resize`].
    ///
    /// # Errors
    /// Returns an error if the backend rejects the new size.
    fn resize(&self, rows: u16, cols: u16) -> io::Result<()>;
}
#[cfg(test)]
pub mod mock {
use super::*;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
/// In-memory [`PtyHandle`] for tests.
///
/// Every field is an `Arc`, so clones of a `MockPty` and the reader, writer
/// and child produced by its `spawn` all share the same state. A test can
/// therefore drive the "terminal" through this handle while the code under
/// test uses the spawned halves.
#[derive(Debug, Clone)]
pub struct MockPty {
/// Bytes queued by `inject_output`, consumed by `MockReader`.
output: Arc<Mutex<VecDeque<u8>>>,
/// Bytes written through `MockWriter`, returned by `captured_input`.
input: Arc<Mutex<Vec<u8>>>,
/// Exit code reported by `MockChild::try_wait`; `None` while "running".
exit_code: Arc<Mutex<Option<i32>>>,
/// Current `(rows, cols)`, updated by both handle-level and child-level resize.
size: Arc<Mutex<(u16, u16)>>,
/// Set by `close`; once set, `MockReader` can report end-of-file.
closed: Arc<Mutex<bool>>,
}
impl Default for MockPty {
fn default() -> Self {
Self {
output: Arc::default(),
input: Arc::default(),
exit_code: Arc::default(),
size: Arc::new(Mutex::new((24, 80))),
closed: Arc::new(Mutex::new(false)),
}
}
}
impl MockPty {
    /// Creates a mock with default state (24x80, nothing queued, not exited).
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Queues `data` to be returned by reads on the spawned reader.
    ///
    /// Does nothing if the output lock is poisoned.
    pub fn inject_output(&self, data: &[u8]) {
        let Ok(mut pending) = self.output.lock() else {
            return;
        };
        pending.extend(data.iter().copied());
    }

    /// Returns a copy of every byte written through spawned writers so far,
    /// or an empty vector if the input lock is poisoned.
    #[must_use]
    pub fn captured_input(&self) -> Vec<u8> {
        match self.input.lock() {
            Ok(written) => written.to_vec(),
            Err(_) => Vec::new(),
        }
    }

    /// Marks the mock child as exited with `code`, as later seen by `try_wait`.
    ///
    /// Does nothing if the exit-code lock is poisoned.
    pub fn set_exit_code(&self, code: i32) {
        let Ok(mut slot) = self.exit_code.lock() else {
            return;
        };
        *slot = Some(code);
    }

    /// Returns the current `(rows, cols)`, falling back to `(24, 80)` if the
    /// size lock is poisoned.
    #[must_use]
    pub fn current_size(&self) -> (u16, u16) {
        match self.size.lock() {
            Ok(dims) => *dims,
            Err(_) => (24, 80),
        }
    }

    /// Sets the shared closed flag so spawned readers can report end-of-file.
    ///
    /// Does nothing if the flag's lock is poisoned.
    pub fn close(&self) {
        let Ok(mut flag) = self.closed.lock() else {
            return;
        };
        *flag = true;
    }
}
/// Reading half returned by [`MockPty`]'s `spawn`; yields the bytes queued
/// with `MockPty::inject_output`.
#[derive(Debug)]
pub struct MockReader {
    /// Pending output, shared with the owning `MockPty`.
    output: Arc<Mutex<VecDeque<u8>>>,
    /// Close flag, shared with the owning `MockPty`.
    closed: Arc<Mutex<bool>>,
}
impl MockReader {
fn new(output: Arc<Mutex<VecDeque<u8>>>, closed: Arc<Mutex<bool>>) -> Self {
Self { output, closed }
}
}
impl Read for MockReader {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
loop {
if let Ok(closed) = self.closed.lock()
&& *closed
{
return Ok(0);
}
if let Ok(mut output) = self.output.lock()
&& !output.is_empty()
{
let len = buf.len().min(output.len());
for (i, b) in output.drain(..len).enumerate() {
buf[i] = b;
}
return Ok(len);
}
std::thread::sleep(std::time::Duration::from_millis(10));
}
}
}
/// Writing half returned by [`MockPty`]'s `spawn`; every byte lands in the
/// log returned by `MockPty::captured_input`.
#[derive(Debug)]
pub struct MockWriter(Arc<Mutex<Vec<u8>>>);
impl Write for MockWriter {
    /// Appends all of `buf` to the shared input log; never writes partially.
    ///
    /// # Errors
    /// Fails with "lock poisoned" if the log's mutex is poisoned.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.0.lock() {
            Ok(mut log) => {
                log.extend_from_slice(buf);
                Ok(buf.len())
            }
            Err(_) => Err(io::Error::other("lock poisoned")),
        }
    }

    /// No-op: writes are recorded in the log immediately.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Child handle returned by [`MockPty`]'s `spawn`; its state is shared with that mock.
#[derive(Debug)]
pub struct MockChild {
    /// Exit code, set through `MockPty::set_exit_code` or `kill`.
    exit_code: Arc<Mutex<Option<i32>>>,
    /// Terminal size, also read by `MockPty::current_size`.
    size: Arc<Mutex<(u16, u16)>>,
}
impl PtyChild for MockChild {
    /// Reports the shared exit code: `None` while "running" (or if the lock is
    /// poisoned). Never fails.
    fn try_wait(&mut self) -> io::Result<Option<i32>> {
        let code = match self.exit_code.lock() {
            Ok(guard) => *guard,
            Err(_) => None,
        };
        Ok(code)
    }

    /// Records an exit code of `-1` unless one is already set. Never fails.
    fn kill(&mut self) -> io::Result<()> {
        if let Ok(mut guard) = self.exit_code.lock() {
            if guard.is_none() {
                *guard = Some(-1);
            }
        }
        Ok(())
    }

    /// Stores the new size in the state shared with the owning mock. Never fails.
    fn resize(&mut self, rows: u16, cols: u16) -> io::Result<()> {
        let Ok(mut dims) = self.size.lock() else {
            return Ok(());
        };
        *dims = (rows, cols);
        Ok(())
    }
}
impl PtyHandle for MockPty {
type Reader = MockReader;
type Writer = MockWriter;
fn spawn(&self, _config: &PtyConfig) -> SpawnResult<Self::Reader, Self::Writer> {
Ok((
MockReader::new(Arc::clone(&self.output), Arc::clone(&self.closed)),
MockWriter(Arc::clone(&self.input)),
Box::new(MockChild {
exit_code: Arc::clone(&self.exit_code),
size: Arc::clone(&self.size),
}),
))
}
fn resize(&self, rows: u16, cols: u16) -> io::Result<()> {
if let Ok(mut size) = self.size.lock() {
*size = (rows, cols);
}
Ok(())
}
}
// The enclosing `mock` module is already `#[cfg(test)]`, so no extra cfg is needed here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn mock_pty_spawn() {
        let pty = MockPty::new();
        let config = PtyConfig::new("test");
        let result = pty.spawn(&config);
        assert!(result.is_ok());
    }

    #[test]
    fn mock_pty_io() {
        let pty = MockPty::new();
        let config = PtyConfig::new("test");
        let (mut reader, mut writer, _child) = pty.spawn(&config).expect("spawn failed");
        pty.inject_output(b"Hello from PTY");
        let mut buf = [0u8; 64];
        let n = reader.read(&mut buf).expect("read failed");
        assert_eq!(&buf[..n], b"Hello from PTY");
        writer.write_all(b"Input data").expect("write failed");
        assert_eq!(pty.captured_input(), b"Input data");
    }

    #[test]
    fn mock_child_exit() {
        let pty = MockPty::new();
        let config = PtyConfig::new("test");
        let (_reader, _writer, mut child) = pty.spawn(&config).expect("spawn failed");
        assert_eq!(child.try_wait().expect("try_wait failed"), None);
        pty.set_exit_code(42);
        assert_eq!(child.try_wait().expect("try_wait failed"), Some(42));
    }

    #[test]
    fn mock_child_kill() {
        let pty = MockPty::new();
        let config = PtyConfig::new("test");
        let (_reader, _writer, mut child) = pty.spawn(&config).expect("spawn failed");
        child.kill().expect("kill failed");
        assert_eq!(child.try_wait().expect("try_wait failed"), Some(-1));
    }

    #[test]
    fn mock_kill_keeps_existing_exit_code() {
        let pty = MockPty::new();
        let (_reader, _writer, mut child) =
            pty.spawn(&PtyConfig::new("test")).expect("spawn failed");
        pty.set_exit_code(0);
        child.kill().expect("kill failed");
        assert_eq!(child.try_wait().expect("try_wait failed"), Some(0));
    }

    #[test]
    fn mock_pty_resize() {
        let pty = MockPty::new();
        assert_eq!(pty.current_size(), (24, 80));
        pty.resize(50, 120).expect("resize failed");
        assert_eq!(pty.current_size(), (50, 120));
    }

    #[test]
    fn mock_child_resize_is_shared_with_pty() {
        let pty = MockPty::new();
        let (_reader, _writer, mut child) =
            pty.spawn(&PtyConfig::new("test")).expect("spawn failed");
        child.resize(30, 100).expect("resize failed");
        assert_eq!(pty.current_size(), (30, 100));
    }

    #[test]
    fn mock_reader_eof_after_close() {
        let pty = MockPty::new();
        let (mut reader, _writer, _child) =
            pty.spawn(&PtyConfig::new("test")).expect("spawn failed");
        pty.close();
        let mut buf = [0u8; 8];
        assert_eq!(reader.read(&mut buf).expect("read failed"), 0);
    }
}
}
pub mod native {
    //! [`PtyHandle`] backed by the host operating system's pseudo-terminals,
    //! via the `portable_pty` crate.

    use super::{PtyChild, PtyConfig, PtyHandle, Read, SpawnResult, Write, io};
    use portable_pty::{CommandBuilder, PtySize, native_pty_system};
    use std::sync::Mutex;

    /// Spawns commands inside real pseudo-terminals.
    pub struct NativePty {
        // `Box<dyn PtySystem + Send>` is not `Sync`. Guarding it with a mutex
        // makes `NativePty` `Send + Sync` (as `PtyHandle` requires) without an
        // unchecked `unsafe impl Sync`.
        system: Mutex<Box<dyn portable_pty::PtySystem + Send>>,
    }

    impl NativePty {
        /// Creates a handle backed by the platform's native PTY system.
        #[must_use]
        pub fn new() -> Self {
            Self {
                system: Mutex::new(native_pty_system()),
            }
        }
    }

    impl Default for NativePty {
        fn default() -> Self {
            Self::new()
        }
    }

    /// Output side of the PTY master: the bytes the child writes to its terminal.
    pub struct NativeReader {
        reader: Box<dyn Read + Send>,
    }

    impl Read for NativeReader {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            self.reader.read(buf)
        }
    }

    /// Input side of the PTY master: bytes written here reach the child as
    /// terminal input.
    pub struct NativeWriter {
        writer: Box<dyn Write + Send>,
    }

    impl Write for NativeWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.writer.write(buf)
        }

        fn flush(&mut self) -> io::Result<()> {
            self.writer.flush()
        }
    }

    /// A process running in a native PTY, together with the PTY master used to
    /// resize its terminal.
    pub struct NativeChild {
        child: Box<dyn portable_pty::Child + Send>,
        master: Box<dyn portable_pty::MasterPty + Send>,
    }

    impl PtyChild for NativeChild {
        fn try_wait(&mut self) -> io::Result<Option<i32>> {
            // `portable_pty` already reports this failure as an `io::Error`;
            // propagate it unchanged so callers keep the original `ErrorKind`.
            let status = self.child.try_wait()?;
            // `exit_code()` is a `u32`. Reinterpret its bits as `i32` (as
            // `std::process::ExitStatus::code` does on Windows) so large codes,
            // e.g. NTSTATUS crash codes, come back as negative values instead
            // of turning a finished process into an error.
            Ok(status.map(|s| i32::from_ne_bytes(s.exit_code().to_ne_bytes())))
        }

        fn kill(&mut self) -> io::Result<()> {
            self.child.kill()
        }

        fn resize(&mut self, rows: u16, cols: u16) -> io::Result<()> {
            self.master
                .resize(pty_size(rows, cols))
                .map_err(io::Error::other)
        }
    }

    /// Builds a `PtySize` of `rows` x `cols` cells with unspecified pixel dimensions.
    fn pty_size(rows: u16, cols: u16) -> PtySize {
        PtySize {
            rows,
            cols,
            pixel_width: 0,
            pixel_height: 0,
        }
    }

    /// Translates a [`PtyConfig`] into a `portable_pty` command.
    fn build_command(config: &PtyConfig) -> CommandBuilder {
        let mut cmd = CommandBuilder::new(&config.cmd);
        cmd.args(&config.args);
        if let Some(cwd) = &config.cwd {
            cmd.cwd(cwd);
        }
        for (key, value) in &config.env {
            cmd.env(key, value);
        }
        cmd
    }

    impl PtyHandle for NativePty {
        type Reader = NativeReader;
        type Writer = NativeWriter;

        /// Opens a PTY of `config.rows` x `config.cols` and runs `config.cmd` in it.
        ///
        /// # Errors
        /// Returns `InvalidInput` if `config.cmd` is empty. Otherwise returns any
        /// failure to open the PTY, obtain its reader/writer, or start the
        /// command. Backend errors are kept as the `io::Error`'s source.
        fn spawn(&self, config: &PtyConfig) -> SpawnResult<Self::Reader, Self::Writer> {
            if config.cmd.is_empty() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "PtyConfig::cmd must not be empty",
                ));
            }
            // The lock guard is a temporary and is released at the end of
            // this statement.
            let pair = self
                .system
                .lock()
                .map_err(|_| io::Error::other("pty system lock poisoned"))?
                .openpty(pty_size(config.rows, config.cols))
                .map_err(io::Error::other)?;
            // Acquire the master's I/O handles *before* starting the command,
            // so a failure here cannot leave a running child with nothing
            // attached to its terminal.
            let reader = pair
                .master
                .try_clone_reader()
                .map_err(io::Error::other)?;
            let writer = pair.master.take_writer().map_err(io::Error::other)?;
            let child = pair
                .slave
                .spawn_command(build_command(config))
                .map_err(io::Error::other)?;
            // The child now holds its own end of the terminal. Release the
            // parent's slave handle right away (as the `portable_pty` examples
            // do) rather than at the end of this function.
            drop(pair.slave);
            Ok((
                NativeReader { reader },
                NativeWriter { writer },
                Box::new(NativeChild {
                    child,
                    master: pair.master,
                }),
            ))
        }

        /// Does nothing and always returns `Ok(())`.
        ///
        /// A `NativePty` can spawn any number of children and does not keep
        /// track of them, so there is no single terminal to resize here. Use
        /// [`PtyChild::resize`] on the child returned by [`PtyHandle::spawn`].
        fn resize(&self, _rows: u16, _cols: u16) -> io::Result<()> {
            Ok(())
        }
    }
}
pub use native::NativePty;
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn pty_config_builder() {
        let built = PtyConfig::new("claude")
            .args(["--model", "opus"])
            .cwd("/tmp")
            .size(40, 120);
        let PtyConfig {
            rows,
            cols,
            cmd,
            args,
            cwd,
            ..
        } = built;
        assert_eq!(cmd, "claude");
        assert_eq!(args, ["--model", "opus"]);
        assert_eq!(cwd.as_deref(), Some(std::path::Path::new("/tmp")));
        assert_eq!((rows, cols), (40, 120));
    }

    #[test]
    fn pty_config_default() {
        let PtyConfig { rows, cols, cmd, .. } = PtyConfig::default();
        assert_eq!((rows, cols), (24, 80));
        assert!(cmd.is_empty());
    }
}