use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::config::SessionConfig;
use crate::error::{ExpectError, Result, SpawnError};
/// Transport that pairs a boxed asynchronous reader with a boxed asynchronous
/// writer and optionally records a process id.
pub struct PtyTransport {
/// Read half; the `AsyncRead` implementation forwards every poll to it.
reader: Box<dyn AsyncRead + Unpin + Send>,
/// Write half; the `AsyncWrite` implementation forwards every poll to it.
writer: Box<dyn AsyncWrite + Unpin + Send>,
/// Process id recorded through `set_pid`; `None` until it is set.
pid: Option<u32>,
}
impl PtyTransport {
    /// Builds a transport from independent read and write halves.
    ///
    /// Both halves are boxed as trait objects and the process id starts out
    /// unset.
    pub fn new<R, W>(reader: R, writer: W) -> Self
    where
        R: AsyncRead + Unpin + Send + 'static,
        W: AsyncWrite + Unpin + Send + 'static,
    {
        let reader: Box<dyn AsyncRead + Unpin + Send> = Box::new(reader);
        let writer: Box<dyn AsyncWrite + Unpin + Send> = Box::new(writer);
        Self {
            reader,
            writer,
            pid: None,
        }
    }

    /// Records the process id associated with this transport, replacing any
    /// previously stored value.
    pub const fn set_pid(&mut self, pid: u32) {
        self.pid = Some(pid);
    }

    /// Returns the recorded process id, or `None` if `set_pid` was never
    /// called.
    #[must_use]
    pub const fn pid(&self) -> Option<u32> {
        self.pid
    }
}
impl AsyncRead for PtyTransport {
    /// Forwards the read poll to the boxed reader half.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // The transport only holds boxed halves, so it can be unpinned freely.
        let reader = &mut self.get_mut().reader;
        Pin::new(reader).poll_read(cx, buf)
    }
}
impl AsyncWrite for PtyTransport {
    /// Forwards the write poll to the boxed writer half.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let writer = &mut self.get_mut().writer;
        Pin::new(writer).poll_write(cx, buf)
    }

    /// Forwards the flush poll to the boxed writer half.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let writer = &mut self.get_mut().writer;
        Pin::new(writer).poll_flush(cx)
    }

    /// Forwards the shutdown poll to the boxed writer half.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let writer = &mut self.get_mut().writer;
        Pin::new(writer).poll_shutdown(cx)
    }
}
/// Settings used by `PtySpawner` when creating a PTY-backed process.
#[derive(Debug, Clone)]
pub struct PtyConfig {
/// Terminal size as `(cols, rows)`, the same order `PtySpawner::set_dimensions` stores.
pub dimensions: (u16, u16),
/// Login-shell flag.
/// NOTE(review): nothing in this file reads this field — confirm where it takes effect.
pub login_shell: bool,
/// How the spawned process's environment is derived; see `EnvMode`.
pub env_mode: EnvMode,
}
impl Default for PtyConfig {
    /// Returns dimensions `(80, 24)`, `login_shell` off and
    /// `EnvMode::Inherit`.
    fn default() -> Self {
        let (cols, rows) = (80, 24);
        Self {
            dimensions: (cols, rows),
            login_shell: false,
            env_mode: EnvMode::Inherit,
        }
    }
}
impl From<&SessionConfig> for PtyConfig {
    /// Derives a PTY configuration from a session configuration.
    ///
    /// The dimensions are copied, `login_shell` is always `false`, and the
    /// environment mode is `Inherit` when the session's `env` is empty and
    /// `Extend` otherwise.
    fn from(config: &SessionConfig) -> Self {
        let dimensions = config.dimensions;
        let env_mode = if config.env.is_empty() {
            EnvMode::Inherit
        } else {
            EnvMode::Extend
        };
        Self {
            dimensions,
            login_shell: false,
            env_mode,
        }
    }
}
/// Strategy for deriving the environment of a spawned process.
///
/// NOTE(review): within this file only the Windows `spawn` path inspects the
/// mode; the Unix `spawn` path does not read it — confirm that is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnvMode {
/// Default mode; also chosen by `From<&SessionConfig>` when the session's `env` is empty.
Inherit,
/// The Windows `spawn` path passes an empty environment map for this mode.
Clear,
/// Chosen by `From<&SessionConfig>` when the session's `env` is non-empty.
/// Presumably the session variables are added on top of the inherited ones — confirm against the caller.
Extend,
}
/// Factory that spawns processes attached to a PTY using a stored `PtyConfig`.
pub struct PtySpawner {
/// Configuration consulted by `spawn`; dimensions can be changed with `set_dimensions`.
config: PtyConfig,
}
impl PtySpawner {
    /// Creates a spawner that uses `PtyConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: PtyConfig::default(),
        }
    }

    /// Creates a spawner that uses the supplied configuration.
    #[must_use]
    pub const fn with_config(config: PtyConfig) -> Self {
        Self { config }
    }

    /// Replaces the dimensions used for PTYs spawned from now on.
    pub const fn set_dimensions(&mut self, cols: u16, rows: u16) {
        self.config.dimensions = (cols, rows);
    }

    /// Spawns `command` with `args` on a freshly allocated PTY (Unix).
    ///
    /// The PTY is opened with the configured dimensions, its master side is
    /// switched to non-blocking mode, and the process is forked. The child
    /// starts a new session, makes the slave its controlling terminal and
    /// standard streams, and executes `command` via `execvp`; if that fails
    /// the child exits with status 1.
    ///
    /// Only `dimensions` is read from the configuration here; `login_shell`
    /// and `env_mode` are not consulted by this implementation.
    ///
    /// The function contains no await points; it is `async` only to mirror
    /// the Windows signature.
    ///
    /// # Errors
    ///
    /// * `SpawnError::InvalidArgument` if `command` or any argument contains
    ///   a NUL byte.
    /// * `SpawnError::PtyAllocation` if `openpty` fails.
    /// * `SpawnError::Io` if the master cannot be made non-blocking or if
    ///   `fork` fails. Both PTY descriptors are closed before returning.
    #[cfg(unix)]
    #[allow(unsafe_code)]
    #[allow(clippy::unused_async)]
    pub async fn spawn(&self, command: &str, args: &[String]) -> Result<PtyHandle> {
        use std::ffi::CString;

        let cmd_cstring = CString::new(command).map_err(|_| {
            ExpectError::Spawn(SpawnError::InvalidArgument {
                kind: "command".to_string(),
                value: command.to_string(),
                reason: "command contains null byte".to_string(),
            })
        })?;

        let mut argv_cstrings: Vec<CString> = Vec::with_capacity(args.len() + 1);
        argv_cstrings.push(cmd_cstring.clone());
        for (idx, arg) in args.iter().enumerate() {
            let arg_cstring = CString::new(arg.as_str()).map_err(|_| {
                ExpectError::Spawn(SpawnError::InvalidArgument {
                    kind: format!("argument[{idx}]"),
                    value: arg.clone(),
                    reason: "argument contains null byte".to_string(),
                })
            })?;
            argv_cstrings.push(arg_cstring);
        }

        // Build the NULL-terminated argv BEFORE forking: allocating in the
        // child of a multi-threaded process is not async-signal-safe.
        let argv_ptrs: Vec<*const libc::c_char> = argv_cstrings
            .iter()
            .map(|s| s.as_ptr())
            .chain(std::iter::once(std::ptr::null()))
            .collect();

        // Apply the configured size at allocation time so the child sees it
        // from its very first read of the window size.
        let (cols, rows) = self.config.dimensions;
        let mut winsize = libc::winsize {
            ws_row: rows,
            ws_col: cols,
            ws_xpixel: 0,
            ws_ypixel: 0,
        };

        let mut master_fd: libc::c_int = -1;
        let mut slave_fd: libc::c_int = -1;
        // SAFETY: every non-null pointer refers to a live local of the type
        // `openpty` expects; the name and termios pointers are null, which
        // `openpty` accepts.
        let opened = unsafe {
            libc::openpty(
                &raw mut master_fd,
                &raw mut slave_fd,
                std::ptr::null_mut(),
                std::ptr::null_mut(),
                &raw mut winsize,
            )
        };
        if opened != 0 {
            return Err(ExpectError::Spawn(SpawnError::PtyAllocation {
                reason: "Failed to open PTY".to_string(),
            }));
        }

        // Switch the master to non-blocking mode before forking so a failure
        // can be reported without leaving a child process behind.
        // SAFETY: `master_fd` was just returned by a successful `openpty`.
        let nonblocking = unsafe {
            let flags = libc::fcntl(master_fd, libc::F_GETFL);
            flags != -1 && libc::fcntl(master_fd, libc::F_SETFL, flags | libc::O_NONBLOCK) != -1
        };
        if !nonblocking {
            // Capture errno before `close` can overwrite it.
            let err = io::Error::last_os_error();
            // SAFETY: both descriptors come from `openpty` and are not used
            // after this point.
            unsafe {
                libc::close(master_fd);
                libc::close(slave_fd);
            }
            return Err(ExpectError::Spawn(SpawnError::Io(err)));
        }

        // SAFETY: `fork` takes no arguments; the child branch below performs
        // only raw libc calls on data prepared before the fork.
        let pid = unsafe { libc::fork() };
        match pid {
            -1 => {
                // Capture errno before `close` can overwrite it.
                let err = io::Error::last_os_error();
                // SAFETY: no child exists, so this process still owns both
                // descriptors and must release them.
                unsafe {
                    libc::close(master_fd);
                    libc::close(slave_fd);
                }
                Err(ExpectError::Spawn(SpawnError::Io(err)))
            }
            0 => {
                // SAFETY: child side of the fork. Only raw libc calls are
                // made, on descriptors and pointers created before the fork;
                // `argv_ptrs` is NULL-terminated and its strings are kept
                // alive by `argv_cstrings`.
                unsafe {
                    libc::close(master_fd);
                    libc::setsid();
                    libc::ioctl(slave_fd, libc::TIOCSCTTY as libc::c_ulong, 0);
                    for stdio_fd in 0..=2 {
                        if libc::dup2(slave_fd, stdio_fd) < 0 {
                            // Never exec with a broken stdin/stdout/stderr.
                            libc::_exit(1);
                        }
                    }
                    if slave_fd > 2 {
                        libc::close(slave_fd);
                    }
                    libc::execvp(cmd_cstring.as_ptr(), argv_ptrs.as_ptr());
                    // Reached only when `execvp` failed.
                    libc::_exit(1);
                }
            }
            child_pid => {
                // SAFETY: the parent no longer needs the slave side; the
                // child holds its own copy.
                unsafe {
                    libc::close(slave_fd);
                }
                Ok(PtyHandle {
                    master_fd,
                    // `fork` returned a positive pid in this branch.
                    pid: child_pid as u32,
                    dimensions: self.config.dimensions,
                })
            }
        }
    }

    /// Spawns `command` with `args` on a Windows ConPTY via `rust_pty`.
    ///
    /// The configured dimensions are passed as the window size. For
    /// `EnvMode::Clear` an empty environment map is supplied; for the other
    /// modes no environment override is passed.
    ///
    /// # Errors
    ///
    /// Returns `SpawnError::PtyAllocation` carrying the backend's message if
    /// the ConPTY spawn fails.
    #[cfg(windows)]
    pub async fn spawn(&self, command: &str, args: &[String]) -> Result<WindowsPtyHandle> {
        use rust_pty::{PtySystem, WindowsPtySystem};

        let pty_config = rust_pty::PtyConfig {
            window_size: self.config.dimensions,
            env: match self.config.env_mode {
                EnvMode::Clear => Some(std::collections::HashMap::new()),
                EnvMode::Inherit | EnvMode::Extend => None,
            },
            ..Default::default()
        };

        let (master, child) =
            WindowsPtySystem::spawn(command, args.iter().map(|s| s.as_str()), &pty_config)
                .await
                .map_err(|e| {
                    ExpectError::Spawn(SpawnError::PtyAllocation {
                        reason: format!("Windows ConPTY spawn failed: {e}"),
                    })
                })?;

        Ok(WindowsPtyHandle {
            master,
            child,
            dimensions: self.config.dimensions,
        })
    }
}
impl Default for PtySpawner {
fn default() -> Self {
Self::new()
}
}
/// Handle to a process spawned on a Unix PTY: the master descriptor plus the
/// child's process id.
///
/// Dropping the handle closes the master descriptor.
#[cfg(unix)]
#[derive(Debug)]
pub struct PtyHandle {
/// Master side of the PTY pair; `PtySpawner::spawn` puts it in non-blocking mode.
master_fd: i32,
/// Process id of the forked child.
pid: u32,
/// Size as `(cols, rows)`; updated after a successful `resize`.
dimensions: (u16, u16),
}
/// Handle to a process spawned on a Windows ConPTY through `rust_pty`.
#[cfg(windows)]
pub struct WindowsPtyHandle {
/// Master side of the ConPTY; used for resizing.
pub(crate) master: rust_pty::WindowsPtyMaster,
/// Child process object; used for pid, liveness and kill.
pub(crate) child: rust_pty::WindowsPtyChild,
/// Size as `(cols, rows)`; updated after a successful `resize`.
dimensions: (u16, u16),
}
#[cfg(windows)]
impl std::fmt::Debug for WindowsPtyHandle {
    /// Prints only the dimensions; the master and child fields are omitted
    /// and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WindowsPtyHandle");
        out.field("dimensions", &self.dimensions);
        out.finish_non_exhaustive()
    }
}
#[cfg(unix)]
impl PtyHandle {
    /// Returns the child's process id.
    #[must_use]
    pub const fn pid(&self) -> u32 {
        self.pid
    }

    /// Returns the last recorded size as `(cols, rows)`.
    #[must_use]
    pub const fn dimensions(&self) -> (u16, u16) {
        self.dimensions
    }

    /// Sets the PTY window size and records it on success.
    ///
    /// # Errors
    ///
    /// Returns `ExpectError::Io` with the OS error if the `TIOCSWINSZ` ioctl
    /// fails; the recorded dimensions are left unchanged in that case.
    #[allow(unsafe_code)]
    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
        let winsize = libc::winsize {
            ws_row: rows,
            ws_col: cols,
            ws_xpixel: 0,
            ws_ypixel: 0,
        };
        // SAFETY: `winsize` is a live, fully initialised local for the whole
        // call, and `TIOCSWINSZ` only reads through the pointer.
        let result =
            unsafe { libc::ioctl(self.master_fd, libc::TIOCSWINSZ as libc::c_ulong, &winsize) };
        if result == 0 {
            self.dimensions = (cols, rows);
            Ok(())
        } else {
            Err(ExpectError::Io(io::Error::last_os_error()))
        }
    }

    /// Blocks until the child changes state and returns an exit code.
    ///
    /// The result is the exit status for a normal exit, `128 + signal` when
    /// the child was terminated by a signal, and `-1` for any other status.
    /// A `waitpid` call interrupted by a signal (`EINTR`) is retried instead
    /// of being reported as a failure.
    ///
    /// # Errors
    ///
    /// Returns `ExpectError::Io` with the OS error if `waitpid` fails for a
    /// reason other than interruption.
    #[allow(unsafe_code)]
    pub fn wait(&self) -> Result<i32> {
        let mut status: libc::c_int = 0;
        loop {
            // SAFETY: `status` is a live local that `waitpid` writes into.
            let result = unsafe { libc::waitpid(self.pid as i32, &raw mut status, 0) };
            if result != -1 {
                break;
            }
            let err = io::Error::last_os_error();
            if err.kind() != io::ErrorKind::Interrupted {
                return Err(ExpectError::Io(err));
            }
            // EINTR: the child has not been reaped yet, so wait again.
        }

        if libc::WIFEXITED(status) {
            Ok(libc::WEXITSTATUS(status))
        } else if libc::WIFSIGNALED(status) {
            Ok(128 + libc::WTERMSIG(status))
        } else {
            Ok(-1)
        }
    }

    /// Sends `signal` to the child process.
    ///
    /// # Errors
    ///
    /// Returns `ExpectError::Io` with the OS error if `kill(2)` fails.
    #[allow(unsafe_code)]
    pub fn signal(&self, signal: i32) -> Result<()> {
        // SAFETY: `kill` takes plain integer arguments and touches no memory
        // owned by this process.
        let result = unsafe { libc::kill(self.pid as i32, signal) };
        if result == 0 {
            Ok(())
        } else {
            Err(ExpectError::Io(io::Error::last_os_error()))
        }
    }

    /// Sends `SIGKILL` to the child process.
    ///
    /// # Errors
    ///
    /// Same as [`PtyHandle::signal`].
    pub fn kill(&self) -> Result<()> {
        self.signal(libc::SIGKILL)
    }
}
#[cfg(windows)]
impl WindowsPtyHandle {
    /// Returns the process id reported by the child object.
    #[must_use]
    pub fn pid(&self) -> u32 {
        self.child.pid()
    }

    /// Returns the last recorded size as `(cols, rows)`.
    #[must_use]
    pub const fn dimensions(&self) -> (u16, u16) {
        self.dimensions
    }

    /// Resizes the ConPTY and records the new size on success.
    ///
    /// # Errors
    ///
    /// Returns `ExpectError::Io` wrapping the backend's message if the
    /// resize fails; the recorded dimensions are left unchanged.
    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
        use rust_pty::{PtyMaster, WindowSize};

        match self.master.resize(WindowSize::new(cols, rows)) {
            Ok(_) => {
                self.dimensions = (cols, rows);
                Ok(())
            }
            Err(e) => Err(ExpectError::Io(io::Error::other(format!(
                "resize failed: {e}"
            )))),
        }
    }

    /// Reports whether the child object considers the process running.
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.child.is_running()
    }

    /// Asks the child object to kill the process.
    ///
    /// # Errors
    ///
    /// Returns `ExpectError::Io` wrapping the backend's message on failure.
    pub fn kill(&mut self) -> Result<()> {
        let outcome = self.child.kill();
        outcome.map_err(|e| ExpectError::Io(io::Error::other(format!("kill failed: {e}"))))
    }
}
#[cfg(unix)]
impl Drop for PtyHandle {
    /// Closes the master descriptor; the result of `close` is ignored.
    ///
    /// `AsyncPty::from_handle` forgets the handle so that this destructor
    /// does not run once the descriptor has been handed over.
    #[allow(unsafe_code)]
    fn drop(&mut self) {
        let fd = self.master_fd;
        // SAFETY: `close` takes a plain integer descriptor; the value is the
        // one stored in this handle.
        let _ = unsafe { libc::close(fd) };
    }
}
/// Asynchronous wrapper around a Unix PTY master descriptor, readable and
/// writable through tokio's `AsyncRead`/`AsyncWrite`.
#[cfg(unix)]
pub struct AsyncPty {
/// The master descriptor wrapped in tokio's `AsyncFd` for readiness polling.
inner: tokio::io::unix::AsyncFd<std::os::unix::io::RawFd>,
/// Process id of the child, copied from the originating `PtyHandle`.
pid: u32,
/// Size as `(cols, rows)`; updated after a successful `resize`.
dimensions: (u16, u16),
}
#[cfg(unix)]
impl AsyncPty {
    /// Converts a `PtyHandle` into an `AsyncPty`, taking over its master
    /// descriptor.
    ///
    /// # Errors
    ///
    /// Returns the error from `AsyncFd::new` if the descriptor cannot be
    /// wrapped. In that case `handle` is dropped normally, so the descriptor
    /// is closed rather than leaked.
    pub fn from_handle(handle: PtyHandle) -> io::Result<Self> {
        // Wrap first: on failure the early return drops `handle`, whose
        // destructor closes the descriptor.
        let inner = tokio::io::unix::AsyncFd::new(handle.master_fd)?;
        let pty = Self {
            inner,
            pid: handle.pid,
            dimensions: handle.dimensions,
        };
        // The descriptor now belongs to `pty`; skip `PtyHandle`'s destructor
        // so it is not closed twice.
        std::mem::forget(handle);
        Ok(pty)
    }

    /// Returns the child's process id.
    #[must_use]
    pub const fn pid(&self) -> u32 {
        self.pid
    }

    /// Returns the last recorded size as `(cols, rows)`.
    #[must_use]
    pub const fn dimensions(&self) -> (u16, u16) {
        self.dimensions
    }

    /// Sets the PTY window size and records it on success.
    ///
    /// # Errors
    ///
    /// Returns `ExpectError::Io` with the OS error if the `TIOCSWINSZ` ioctl
    /// fails; the recorded dimensions are left unchanged in that case.
    #[allow(unsafe_code)]
    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
        let winsize = libc::winsize {
            ws_row: rows,
            ws_col: cols,
            ws_xpixel: 0,
            ws_ypixel: 0,
        };
        // SAFETY: `winsize` is a live, fully initialised local for the whole
        // call, and `TIOCSWINSZ` only reads through the pointer.
        let result = unsafe {
            libc::ioctl(
                *self.inner.get_ref(),
                libc::TIOCSWINSZ as libc::c_ulong,
                &winsize,
            )
        };
        if result == 0 {
            self.dimensions = (cols, rows);
            Ok(())
        } else {
            Err(ExpectError::Io(io::Error::last_os_error()))
        }
    }

    /// Sends `signal` to the child process.
    ///
    /// # Errors
    ///
    /// Returns `ExpectError::Io` with the OS error if `kill(2)` fails.
    #[allow(unsafe_code)]
    pub fn signal(&self, signal: i32) -> Result<()> {
        // SAFETY: `kill` takes plain integer arguments and touches no memory
        // owned by this process.
        let result = unsafe { libc::kill(self.pid as i32, signal) };
        if result == 0 {
            Ok(())
        } else {
            Err(ExpectError::Io(io::Error::last_os_error()))
        }
    }

    /// Sends `SIGKILL` to the child process.
    ///
    /// # Errors
    ///
    /// Same as [`AsyncPty::signal`].
    pub fn kill(&self) -> Result<()> {
        self.signal(libc::SIGKILL)
    }
}
#[cfg(unix)]
impl AsyncRead for AsyncPty {
    /// Reads from the master descriptor once it is reported readable.
    ///
    /// A read that fails with `WouldBlock` clears the readiness flag and
    /// polls again; any other error is returned to the caller. A successful
    /// read advances `buf` by the number of bytes received.
    #[allow(unsafe_code)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let raw_fd = *self.inner.get_ref();
        loop {
            let mut ready = match self.inner.poll_read_ready(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(outcome) => outcome?,
            };

            let dest = buf.initialize_unfilled();
            // SAFETY: `dest` is an initialised, writable slice and the
            // length passed is exactly its length.
            let count = unsafe {
                libc::read(
                    raw_fd,
                    dest.as_mut_ptr().cast::<libc::c_void>(),
                    dest.len(),
                )
            };

            if count < 0 {
                let err = io::Error::last_os_error();
                if err.kind() != io::ErrorKind::WouldBlock {
                    return Poll::Ready(Err(err));
                }
                // Spurious readiness: reset the flag and wait again.
                ready.clear_ready();
            } else {
                buf.advance(count as usize);
                return Poll::Ready(Ok(()));
            }
        }
    }
}
#[cfg(unix)]
impl AsyncWrite for AsyncPty {
    /// Writes to the master descriptor once it is reported writable.
    ///
    /// A write that fails with `WouldBlock` clears the readiness flag and
    /// polls again; any other error is returned to the caller. On success
    /// the number of bytes accepted by `write(2)` is returned.
    #[allow(unsafe_code)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let raw_fd = *self.inner.get_ref();
        loop {
            let mut ready = match self.inner.poll_write_ready(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(outcome) => outcome?,
            };

            // SAFETY: `buf` is a valid readable slice and the length passed
            // is exactly its length.
            let count =
                unsafe { libc::write(raw_fd, buf.as_ptr().cast::<libc::c_void>(), buf.len()) };

            if count < 0 {
                let err = io::Error::last_os_error();
                if err.kind() != io::ErrorKind::WouldBlock {
                    return Poll::Ready(Err(err));
                }
                // Spurious readiness: reset the flag and wait again.
                ready.clear_ready();
            } else {
                return Poll::Ready(Ok(count as usize));
            }
        }
    }

    /// Always ready: this type keeps no buffer of its own to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Always ready: shutdown performs no action; the descriptor is closed
    /// when the value is dropped.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
#[cfg(unix)]
impl Drop for AsyncPty {
    /// Closes the PTY master descriptor held inside the `AsyncFd`; the result
    /// of `close` is ignored.
    ///
    /// NOTE(review): this body runs before the `inner` field's own
    /// destructor, so the descriptor is already closed by the time the
    /// `AsyncFd` is dropped — confirm that ordering is acceptable for the
    /// runtime's deregistration.
    #[allow(unsafe_code)]
    fn drop(&mut self) {
        let fd = *self.inner.get_ref();
        // SAFETY: `close` takes a plain integer descriptor; the value is the
        // one stored in this wrapper.
        let _ = unsafe { libc::close(fd) };
    }
}
#[cfg(unix)]
impl std::fmt::Debug for AsyncPty {
    /// Prints the raw descriptor, the child's pid and the recorded
    /// dimensions.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fd = self.inner.get_ref();
        let mut out = f.debug_struct("AsyncPty");
        out.field("fd", fd);
        out.field("pid", &self.pid);
        out.field("dimensions", &self.dimensions);
        out.finish()
    }
}
/// Asynchronous wrapper around a Windows ConPTY master and its child process.
#[cfg(windows)]
pub struct WindowsAsyncPty {
/// ConPTY master; all read/write/flush/shutdown polls are forwarded to it.
master: rust_pty::WindowsPtyMaster,
/// Child process object; used for liveness checks and kill.
child: rust_pty::WindowsPtyChild,
/// Process id captured from the child when the wrapper was built.
pid: u32,
/// Size as `(cols, rows)`; updated after a successful `resize`.
dimensions: (u16, u16),
}
#[cfg(windows)]
impl WindowsAsyncPty {
    /// Converts a `WindowsPtyHandle` into an async wrapper, capturing the
    /// child's pid at conversion time.
    pub fn from_handle(handle: WindowsPtyHandle) -> Self {
        let WindowsPtyHandle {
            master,
            child,
            dimensions,
        } = handle;
        let pid = child.pid();
        Self {
            master,
            child,
            pid,
            dimensions,
        }
    }

    /// Returns the pid captured when the wrapper was created.
    #[must_use]
    pub const fn pid(&self) -> u32 {
        self.pid
    }

    /// Returns the last recorded size as `(cols, rows)`.
    #[must_use]
    pub const fn dimensions(&self) -> (u16, u16) {
        self.dimensions
    }

    /// Resizes the ConPTY and records the new size on success.
    ///
    /// # Errors
    ///
    /// Returns `ExpectError::Io` wrapping the backend's message if the
    /// resize fails; the recorded dimensions are left unchanged.
    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
        use rust_pty::{PtyMaster, WindowSize};

        match self.master.resize(WindowSize::new(cols, rows)) {
            Ok(_) => {
                self.dimensions = (cols, rows);
                Ok(())
            }
            Err(e) => Err(ExpectError::Io(io::Error::other(format!(
                "resize failed: {e}"
            )))),
        }
    }

    /// Reports whether the child object considers the process running.
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.child.is_running()
    }

    /// Asks the child object to kill the process.
    ///
    /// # Errors
    ///
    /// Returns `ExpectError::Io` wrapping the backend's message on failure.
    pub fn kill(&mut self) -> Result<()> {
        let outcome = self.child.kill();
        outcome.map_err(|e| ExpectError::Io(io::Error::other(format!("kill failed: {e}"))))
    }
}
#[cfg(windows)]
impl AsyncRead for WindowsAsyncPty {
    /// Forwards the read poll to the ConPTY master.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let master = &mut self.get_mut().master;
        Pin::new(master).poll_read(cx, buf)
    }
}
#[cfg(windows)]
impl AsyncWrite for WindowsAsyncPty {
    /// Forwards the write poll to the ConPTY master.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let master = &mut self.get_mut().master;
        Pin::new(master).poll_write(cx, buf)
    }

    /// Forwards the flush poll to the ConPTY master.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let master = &mut self.get_mut().master;
        Pin::new(master).poll_flush(cx)
    }

    /// Forwards the shutdown poll to the ConPTY master.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let master = &mut self.get_mut().master;
        Pin::new(master).poll_shutdown(cx)
    }
}
#[cfg(windows)]
impl std::fmt::Debug for WindowsAsyncPty {
    /// Prints the pid and dimensions; the master and child fields are
    /// omitted and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WindowsAsyncPty");
        out.field("pid", &self.pid);
        out.field("dimensions", &self.dimensions);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses dimensions (80, 24) and inherits the
    /// environment.
    #[test]
    fn pty_config_default() {
        let PtyConfig {
            dimensions: (cols, rows),
            env_mode,
            ..
        } = PtyConfig::default();
        assert_eq!(cols, 80);
        assert_eq!(rows, 24);
        assert_eq!(env_mode, EnvMode::Inherit);
    }

    /// Dimensions are carried over from the session configuration.
    #[test]
    fn pty_config_from_session() {
        let session_config = SessionConfig {
            dimensions: (120, 40),
            ..Default::default()
        };
        let (cols, rows) = PtyConfig::from(&session_config).dimensions;
        assert_eq!(cols, 120);
        assert_eq!(rows, 40);
    }

    /// A NUL byte inside the command is rejected with a descriptive error.
    #[cfg(unix)]
    #[tokio::test]
    async fn spawn_rejects_null_byte_in_command() {
        let outcome = PtySpawner::new().spawn("test\0command", &[]).await;
        assert!(outcome.is_err());
        let err_str = outcome.unwrap_err().to_string();
        assert!(
            err_str.contains("null byte"),
            "Expected error about null byte, got: {err_str}"
        );
    }

    /// A NUL byte inside an argument is rejected with a descriptive error.
    #[cfg(unix)]
    #[tokio::test]
    async fn spawn_rejects_null_byte_in_args() {
        let args = ["hello\0world".to_string()];
        let outcome = PtySpawner::new().spawn("/bin/echo", &args).await;
        assert!(outcome.is_err());
        let err_str = outcome.unwrap_err().to_string();
        assert!(
            err_str.contains("null byte"),
            "Expected error about null byte, got: {err_str}"
        );
    }
}