use std::path::{Path, PathBuf};
use std::time::Duration;
use async_trait::async_trait;
use portable_pty::{CommandBuilder, MasterPty, PtySize, native_pty_system};
use tokio::sync::mpsc;
use tracing::{debug, trace, warn};
use crate::core::{AgentEvent, ClientFrame, Content, StopReason, TextChannel, Usage};
use crate::driver::{Driver, DriverError};
/// Turns the raw byte stream read from a PTY into [`AgentEvent`]s.
///
/// The parser is moved onto the PTY reader thread (hence `Send + 'static`)
/// and sees every chunk read from the PTY master, in order.
pub trait AgentParser: Send + 'static {
    /// Short, human-readable identifier for this parser.
    fn name(&self) -> &str;

    /// Consumes one chunk of PTY output and returns the events it produced.
    ///
    /// Chunk boundaries follow individual `read` calls; they are not aligned
    /// to characters, lines, or messages.
    fn on_bytes(&mut self, bytes: &[u8]) -> Vec<AgentEvent>;

    /// Called at most once, after the reader's read loop ends (EOF or read
    /// error). The default implementation emits nothing.
    fn on_eof(&mut self) -> Vec<AgentEvent> {
        vec![]
    }
}
/// Pass-through parser: each non-empty PTY chunk becomes one assistant text
/// chunk, decoded lossily as UTF-8, and end of output becomes `Done`.
#[derive(Default, Debug)]
pub struct RawParser;
impl AgentParser for RawParser {
    fn name(&self) -> &str {
        "raw"
    }

    /// Emits the chunk verbatim as assistant text; empty chunks emit nothing.
    ///
    /// NOTE(review): decoding happens per chunk, so a multi-byte UTF-8
    /// character split across two reads is rendered as U+FFFD replacement
    /// characters.
    fn on_bytes(&mut self, bytes: &[u8]) -> Vec<AgentEvent> {
        match bytes {
            [] => Vec::new(),
            _ => {
                let text = String::from_utf8_lossy(bytes).into_owned();
                vec![AgentEvent::TextChunk {
                    msg_id: String::new(),
                    text,
                    channel: TextChannel::Assistant,
                }]
            }
        }
    }

    /// Reports end of turn, with default (empty) usage.
    fn on_eof(&mut self) -> Vec<AgentEvent> {
        let usage = Usage::default();
        vec![AgentEvent::Done {
            stop_reason: StopReason::EndTurn,
            usage,
        }]
    }
}
/// Parser that feeds PTY output through a `vt100` terminal emulator and
/// emits changes to the rendered screen as plain assistant text.
pub struct VtPlainParser {
// Terminal emulator state; every PTY chunk is processed into it.
vt: vt100::Parser,
// Screen contents as of the previous `on_bytes` call, used for diffing.
last_screen: String,
}
impl std::fmt::Debug for VtPlainParser {
    /// Reports only the length of the cached screen; the emulator state is
    /// omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let screen_len = self.last_screen.len();
        let mut out = f.debug_struct("VtPlainParser");
        out.field("last_screen_len", &screen_len);
        out.finish()
    }
}
impl VtPlainParser {
    /// Scrollback rows retained by [`VtPlainParser::new`].
    const DEFAULT_SCROLLBACK: usize = 10_000;

    /// Creates a parser that emulates a `rows` x `cols` terminal with a
    /// scrollback of 10 000 rows.
    ///
    /// NOTE(review): `rows`/`cols` presumably should match the PTY size the
    /// driver was spawned with; nothing here enforces that.
    pub fn new(rows: u16, cols: u16) -> Self {
        Self::with_scrollback(rows, cols, Self::DEFAULT_SCROLLBACK)
    }

    /// Creates a parser with an explicit scrollback length (in rows).
    ///
    /// This parser diffs `screen().contents()`. NOTE(review): that appears to
    /// cover only the visible rows, so a smaller scrollback (even `0`) may
    /// bound memory without changing output. Confirm before relying on it.
    pub fn with_scrollback(rows: u16, cols: u16, scrollback: usize) -> Self {
        Self {
            vt: vt100::Parser::new(rows, cols, scrollback),
            last_screen: String::new(),
        }
    }
}
impl AgentParser for VtPlainParser {
    fn name(&self) -> &str {
        "vt100-plain"
    }

    /// Processes `bytes` through the terminal emulator and reports how the
    /// rendered screen changed.
    ///
    /// An unchanged screen yields nothing. If the new screen only appends to
    /// the previous one, just the appended suffix is emitted. Any other change
    /// emits a repaint marker followed by the whole screen.
    fn on_bytes(&mut self, bytes: &[u8]) -> Vec<AgentEvent> {
        self.vt.process(bytes);
        let current = self.vt.screen().contents();
        if current == self.last_screen {
            return Vec::new();
        }
        // `current` differs from `last_screen` here. An appended suffix is
        // therefore never empty, and a repaint always carries its marker, so
        // `text` is never empty.
        let text = match current.strip_prefix(self.last_screen.as_str()) {
            Some(appended) => appended.to_owned(),
            None => format!("\n--- screen repaint ---\n{}", current),
        };
        self.last_screen = current;
        vec![AgentEvent::TextChunk {
            msg_id: String::new(),
            text,
            channel: TextChannel::Assistant,
        }]
    }

    /// Reports end of turn, with default (empty) usage.
    fn on_eof(&mut self) -> Vec<AgentEvent> {
        let usage = Usage::default();
        vec![AgentEvent::Done {
            stop_reason: StopReason::EndTurn,
            usage,
        }]
    }
}
/// [`Driver`] that runs an agent command inside a pseudo-terminal.
///
/// Input bytes are queued to a writer thread; output is parsed into events
/// on a reader thread. Construct with [`PtyDriver::builder`].
pub struct PtyDriver {
// Queue to the writer thread; `None` once input has been closed.
input_tx: Option<mpsc::Sender<Vec<u8>>>,
// Events produced by the reader thread's parser.
event_rx: mpsc::Receiver<AgentEvent>,
// PTY master, kept alive here and used by `resize`.
master: Box<dyn MasterPty + Send>,
// Set by the reader thread after its read loop ends and by the child
// waiter after the child exits; in this module only read by `Debug`.
exited: std::sync::Arc<std::sync::atomic::AtomicBool>,
}
impl std::fmt::Debug for PtyDriver {
    /// Shows whether input is still open and whether exit has been flagged;
    /// the channels and the PTY master are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::sync::atomic::Ordering;
        let input_open = self.input_tx.is_some();
        let exited = self.exited.load(Ordering::Relaxed);
        f.debug_struct("PtyDriver")
            .field("input_open", &input_open)
            .field("exited", &exited)
            .finish()
    }
}
impl PtyDriver {
pub fn builder(command: impl Into<String>) -> PtyDriverBuilder {
PtyDriverBuilder {
command: command.into(),
args: Vec::new(),
cwd: None,
env: Vec::new(),
env_remove: Vec::new(),
size: PtySize {
rows: 50,
cols: 200,
pixel_width: 0,
pixel_height: 0,
},
}
}
pub fn resize(&self, rows: u16, cols: u16) -> Result<(), DriverError> {
self.master
.resize(PtySize {
rows,
cols,
pixel_width: 0,
pixel_height: 0,
})
.map_err(|e| DriverError::Io(std::io::Error::other(e.to_string())))
}
pub fn close_input(&mut self) {
self.input_tx = None;
}
pub async fn send_bytes(&mut self, bytes: &[u8]) -> Result<(), DriverError> {
let tx = self.input_tx.as_ref().ok_or(DriverError::AgentExited)?;
tx.send(bytes.to_vec())
.await
.map_err(|_| DriverError::AgentExited)?;
Ok(())
}
}
#[async_trait]
impl Driver for PtyDriver {
async fn send(&mut self, frame: ClientFrame) -> Result<(), DriverError> {
match frame {
ClientFrame::Prompt { content } => {
for c in content {
if let Content::Text(t) = c {
self.send_bytes(t.as_bytes()).await?;
}
}
self.send_bytes(b"\r").await?;
Ok(())
}
ClientFrame::Cancel => {
self.send_bytes(b"\x03").await
}
ClientFrame::AskUserAnswer { value, .. } => {
let text = value
.as_str()
.map(String::from)
.unwrap_or_else(|| value.to_string());
self.send_bytes(text.as_bytes()).await?;
self.send_bytes(b"\r").await
}
ClientFrame::PermissionResponse { decision, .. } => {
use crate::core::PermissionDecision::*;
let key: &[u8] = match decision {
AllowOnce | AllowAlways => b"y\r",
_ => b"n\r",
};
self.send_bytes(key).await
}
}
}
async fn next_event(&mut self) -> Option<AgentEvent> {
self.event_rx.recv().await
}
async fn shutdown(&mut self) -> Result<(), DriverError> {
self.input_tx = None;
Ok(())
}
}
/// Builder for [`PtyDriver`]: created by [`PtyDriver::builder`] and consumed
/// by [`PtyDriverBuilder::spawn`].
#[derive(Debug)]
pub struct PtyDriverBuilder {
// Program to run.
command: String,
// Arguments, in order.
args: Vec<String>,
// Working-directory override, if any.
cwd: Option<PathBuf>,
// Variables set on top of the inherited environment (applied last).
env: Vec<(String, String)>,
// Names of inherited variables to drop.
env_remove: Vec<String>,
// Initial PTY dimensions.
size: PtySize,
}
impl PtyDriverBuilder {
    /// Appends one argument to the command line.
    pub fn arg(mut self, a: impl Into<String>) -> Self {
        self.args.push(a.into());
        self
    }

    /// Appends several arguments, preserving their order.
    pub fn args<I, S>(mut self, args: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.args.extend(args.into_iter().map(Into::into));
        self
    }

    /// Sets the child's working directory.
    pub fn cwd(mut self, p: impl AsRef<Path>) -> Self {
        self.cwd = Some(p.as_ref().to_path_buf());
        self
    }

    /// Sets an environment variable for the child.
    ///
    /// Explicit variables are applied after inheritance and `env_remove`, so
    /// they always take effect.
    pub fn env(mut self, k: impl Into<String>, v: impl Into<String>) -> Self {
        self.env.push((k.into(), v.into()));
        self
    }

    /// Keeps an inherited environment variable from reaching the child.
    ///
    /// Names are matched exactly, except on Windows, where they are matched
    /// ASCII-case-insensitively like the OS does.
    pub fn env_remove(mut self, k: impl Into<String>) -> Self {
        self.env_remove.push(k.into());
        self
    }

    /// Sets the initial PTY size in character cells.
    pub fn size(mut self, rows: u16, cols: u16) -> Self {
        self.size.rows = rows;
        self.size.cols = cols;
        self
    }

    /// Opens a PTY, starts the command in it, and launches the reader, writer
    /// and child-waiter threads.
    ///
    /// The child receives this process's environment minus `env_remove`,
    /// plus the explicit `env` entries.
    ///
    /// # Errors
    /// - [`DriverError::SpawnFailed`] if the PTY cannot be opened or the
    ///   command cannot be started.
    /// - [`DriverError::Io`] if the PTY master's reader or writer cannot be
    ///   obtained. In that case the command has not been started.
    ///
    /// # Panics
    /// Panics if one of the background OS threads cannot be spawned.
    pub fn spawn<P: AgentParser>(self, parser: P) -> Result<PtyDriver, DriverError> {
        // Compares environment variable names the way the host OS does:
        // ASCII-case-insensitively on Windows, exactly elsewhere.
        fn same_env_name(a: &str, b: &str) -> bool {
            if cfg!(windows) {
                a.eq_ignore_ascii_case(b)
            } else {
                a == b
            }
        }

        let PtyDriverBuilder {
            command,
            args,
            cwd,
            env,
            env_remove,
            size,
        } = self;

        let pty_system = native_pty_system();
        let pair = pty_system
            .openpty(size)
            .map_err(|e| DriverError::SpawnFailed(std::io::Error::other(e.to_string())))?;

        // Take both master handles before starting the child. If either
        // fails, the command has not been launched yet, so no agent process
        // is left running with nobody reading its output or waiting on it.
        let reader = pair
            .master
            .try_clone_reader()
            .map_err(|e| DriverError::Io(std::io::Error::other(e.to_string())))?;
        let writer = pair
            .master
            .take_writer()
            .map_err(|e| DriverError::Io(std::io::Error::other(e.to_string())))?;

        let mut builder = CommandBuilder::new(&command);
        // Rebuild the environment from this process's variables so that
        // `env_remove` entries can be filtered out.
        builder.env_clear();
        for (k, v) in std::env::vars_os() {
            let name = k.to_string_lossy();
            if env_remove.iter().any(|r| same_env_name(r, &name)) {
                continue;
            }
            builder.env(k, v);
        }
        for a in args {
            builder.arg(a);
        }
        if let Some(p) = cwd {
            builder.cwd(p);
        }
        // Explicit variables go last so they win over inherited values.
        for (k, v) in env {
            builder.env(k, v);
        }

        debug!(command = %command, "spawning PTY agent");
        let child = pair
            .slave
            .spawn_command(builder)
            .map_err(|e| DriverError::SpawnFailed(std::io::Error::other(e.to_string())))?;
        // Release our handle on the slave side. Once the child's side closes,
        // the reader can then observe end of output on the master.
        drop(pair.slave);

        let (input_tx, input_rx) = mpsc::channel::<Vec<u8>>(64);
        let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(256);
        let exited = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
        spawn_reader_thread(reader, parser, event_tx.clone(), std::sync::Arc::clone(&exited));
        spawn_writer_thread(writer, input_rx);
        spawn_child_waiter(child, event_tx, std::sync::Arc::clone(&exited));

        Ok(PtyDriver {
            input_tx: Some(input_tx),
            event_rx,
            master: pair.master,
            exited,
        })
    }
}
fn spawn_reader_thread<P: AgentParser>(
mut reader: Box<dyn std::io::Read + Send>,
mut parser: P,
tx: mpsc::Sender<AgentEvent>,
exited: std::sync::Arc<std::sync::atomic::AtomicBool>,
) {
std::thread::Builder::new()
.name("cap-rs-pty-reader".into())
.spawn(move || {
let mut buf = [0u8; 8192];
loop {
match reader.read(&mut buf) {
Ok(0) => {
trace!("PTY reader: EOF");
break;
}
Ok(n) => {
let events = parser.on_bytes(&buf[..n]);
for ev in events {
if tx.blocking_send(ev).is_err() {
trace!("PTY reader: receiver dropped, exiting");
return;
}
}
}
Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) => {
warn!(error = %e, "PTY reader: read error");
break;
}
}
}
for ev in parser.on_eof() {
let _ = tx.blocking_send(ev);
}
exited.store(true, std::sync::atomic::Ordering::Relaxed);
})
.expect("failed to spawn PTY reader thread");
}
fn spawn_writer_thread(
mut writer: Box<dyn std::io::Write + Send>,
mut rx: mpsc::Receiver<Vec<u8>>,
) {
std::thread::Builder::new()
.name("cap-rs-pty-writer".into())
.spawn(move || {
while let Some(bytes) = rx.blocking_recv() {
if let Err(e) = writer.write_all(&bytes) {
warn!(error = %e, "PTY writer: write failed");
return;
}
if let Err(e) = writer.flush() {
warn!(error = %e, "PTY writer: flush failed");
return;
}
}
trace!("PTY writer: input channel closed, exiting");
})
.expect("failed to spawn PTY writer thread");
}
fn spawn_child_waiter(
mut child: Box<dyn portable_pty::Child + Send + Sync>,
event_tx: mpsc::Sender<AgentEvent>,
exited: std::sync::Arc<std::sync::atomic::AtomicBool>,
) {
std::thread::Builder::new()
.name("cap-rs-pty-waiter".into())
.spawn(move || {
let _ = child.wait();
std::thread::sleep(Duration::from_millis(50));
exited.store(true, std::sync::atomic::Ordering::Relaxed);
drop(event_tx);
})
.expect("failed to spawn PTY child waiter thread");
}