1use std::path::{Path, PathBuf};
17use std::time::Duration;
18
19use async_trait::async_trait;
20use portable_pty::{CommandBuilder, MasterPty, PtySize, native_pty_system};
21use tokio::sync::mpsc;
22use tracing::{debug, trace, warn};
23
24use crate::core::{AgentEvent, ClientFrame, Content, StopReason, TextChannel, Usage};
25use crate::driver::{Driver, DriverError};
26
27pub trait AgentParser: Send + 'static {
39 fn name(&self) -> &str;
41
42 fn on_bytes(&mut self, bytes: &[u8]) -> Vec<AgentEvent>;
46
47 fn on_eof(&mut self) -> Vec<AgentEvent> {
50 Vec::new()
51 }
52}
53
/// Pass-through parser with no terminal emulation: each chunk of PTY output
/// becomes one assistant text chunk, decoded lossily as UTF-8.
#[derive(Default, Debug)]
pub struct RawParser;
59
60impl AgentParser for RawParser {
61 fn name(&self) -> &str {
62 "raw"
63 }
64
65 fn on_bytes(&mut self, bytes: &[u8]) -> Vec<AgentEvent> {
66 if bytes.is_empty() {
67 return Vec::new();
68 }
69 vec![AgentEvent::TextChunk {
70 msg_id: String::new(),
71 text: String::from_utf8_lossy(bytes).into_owned(),
72 channel: TextChannel::Assistant,
73 }]
74 }
75
76 fn on_eof(&mut self) -> Vec<AgentEvent> {
77 vec![AgentEvent::Done {
78 stop_reason: StopReason::EndTurn,
79 usage: Usage::default(),
80 }]
81 }
82}
83
84pub struct VtPlainParser {
88 vt: vt100::Parser,
89 last_screen: String,
90}
91
92impl std::fmt::Debug for VtPlainParser {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("VtPlainParser")
95 .field("last_screen_len", &self.last_screen.len())
96 .finish()
97 }
98}
99
100impl VtPlainParser {
101 pub fn new(rows: u16, cols: u16) -> Self {
102 Self {
103 vt: vt100::Parser::new(rows, cols, 10_000),
104 last_screen: String::new(),
105 }
106 }
107}
108
109impl AgentParser for VtPlainParser {
110 fn name(&self) -> &str {
111 "vt100-plain"
112 }
113
114 fn on_bytes(&mut self, bytes: &[u8]) -> Vec<AgentEvent> {
115 self.vt.process(bytes);
116 let screen = self.vt.screen().contents();
117 if screen == self.last_screen {
118 return Vec::new();
119 }
120 let delta = if screen.starts_with(&self.last_screen) {
122 screen[self.last_screen.len()..].to_string()
123 } else {
124 format!("\n--- screen repaint ---\n{}", screen)
126 };
127 self.last_screen = screen;
128 if delta.is_empty() {
129 return Vec::new();
130 }
131 vec![AgentEvent::TextChunk {
132 msg_id: String::new(),
133 text: delta,
134 channel: TextChannel::Assistant,
135 }]
136 }
137
138 fn on_eof(&mut self) -> Vec<AgentEvent> {
139 vec![AgentEvent::Done {
140 stop_reason: StopReason::EndTurn,
141 usage: Usage::default(),
142 }]
143 }
144}
145
146pub struct PtyDriver {
154 input_tx: Option<mpsc::Sender<Vec<u8>>>,
157
158 event_rx: mpsc::Receiver<AgentEvent>,
160
161 master: Box<dyn MasterPty + Send>,
164
165 exited: std::sync::Arc<std::sync::atomic::AtomicBool>,
167}
168
169impl std::fmt::Debug for PtyDriver {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 f.debug_struct("PtyDriver")
172 .field("input_open", &self.input_tx.is_some())
173 .field("exited", &self.exited.load(std::sync::atomic::Ordering::Relaxed))
174 .finish()
175 }
176}
177
178impl PtyDriver {
179 pub fn builder(command: impl Into<String>) -> PtyDriverBuilder {
181 PtyDriverBuilder {
182 command: command.into(),
183 args: Vec::new(),
184 cwd: None,
185 env: Vec::new(),
186 env_remove: Vec::new(),
187 size: PtySize {
188 rows: 50,
189 cols: 200,
190 pixel_width: 0,
191 pixel_height: 0,
192 },
193 }
194 }
195
196 pub fn resize(&self, rows: u16, cols: u16) -> Result<(), DriverError> {
198 self.master
199 .resize(PtySize {
200 rows,
201 cols,
202 pixel_width: 0,
203 pixel_height: 0,
204 })
205 .map_err(|e| DriverError::Io(std::io::Error::other(e.to_string())))
206 }
207
208 pub fn close_input(&mut self) {
212 self.input_tx = None;
213 }
214
215 pub async fn send_bytes(&mut self, bytes: &[u8]) -> Result<(), DriverError> {
218 let tx = self.input_tx.as_ref().ok_or(DriverError::AgentExited)?;
219 tx.send(bytes.to_vec())
220 .await
221 .map_err(|_| DriverError::AgentExited)?;
222 Ok(())
223 }
224}
225
226#[async_trait]
227impl Driver for PtyDriver {
228 async fn send(&mut self, frame: ClientFrame) -> Result<(), DriverError> {
229 match frame {
230 ClientFrame::Prompt { content } => {
231 for c in content {
232 if let Content::Text(t) = c {
233 self.send_bytes(t.as_bytes()).await?;
234 }
235 }
237 self.send_bytes(b"\r").await?;
239 Ok(())
240 }
241 ClientFrame::Cancel => {
242 self.send_bytes(b"\x03").await
244 }
245 ClientFrame::AskUserAnswer { value, .. } => {
246 let text = value
248 .as_str()
249 .map(String::from)
250 .unwrap_or_else(|| value.to_string());
251 self.send_bytes(text.as_bytes()).await?;
252 self.send_bytes(b"\r").await
253 }
254 ClientFrame::PermissionResponse { decision, .. } => {
255 use crate::core::PermissionDecision::*;
256 let key: &[u8] = match decision {
257 AllowOnce | AllowAlways => b"y\r",
258 _ => b"n\r",
259 };
260 self.send_bytes(key).await
261 }
262 }
263 }
264
265 async fn next_event(&mut self) -> Option<AgentEvent> {
266 self.event_rx.recv().await
267 }
268
269 async fn shutdown(&mut self) -> Result<(), DriverError> {
270 self.input_tx = None;
272 Ok(())
275 }
276}
277
278#[derive(Debug)]
283pub struct PtyDriverBuilder {
284 command: String,
285 args: Vec<String>,
286 cwd: Option<PathBuf>,
287 env: Vec<(String, String)>,
288 env_remove: Vec<String>,
289 size: PtySize,
290}
291
292impl PtyDriverBuilder {
293 pub fn arg(mut self, a: impl Into<String>) -> Self {
294 self.args.push(a.into());
295 self
296 }
297
298 pub fn args<I, S>(mut self, args: I) -> Self
299 where
300 I: IntoIterator<Item = S>,
301 S: Into<String>,
302 {
303 for a in args {
304 self.args.push(a.into());
305 }
306 self
307 }
308
309 pub fn cwd(mut self, p: impl AsRef<Path>) -> Self {
310 self.cwd = Some(p.as_ref().to_path_buf());
311 self
312 }
313
314 pub fn env(mut self, k: impl Into<String>, v: impl Into<String>) -> Self {
315 self.env.push((k.into(), v.into()));
316 self
317 }
318
319 pub fn env_remove(mut self, k: impl Into<String>) -> Self {
320 self.env_remove.push(k.into());
321 self
322 }
323
324 pub fn size(mut self, rows: u16, cols: u16) -> Self {
325 self.size.rows = rows;
326 self.size.cols = cols;
327 self
328 }
329
330 pub fn spawn<P: AgentParser>(self, parser: P) -> Result<PtyDriver, DriverError> {
332 let PtyDriverBuilder {
333 command,
334 args,
335 cwd,
336 env,
337 env_remove,
338 size,
339 } = self;
340
341 let pty_system = native_pty_system();
342 let pair = pty_system
343 .openpty(size)
344 .map_err(|e| DriverError::SpawnFailed(std::io::Error::other(e.to_string())))?;
345
346 let mut builder = CommandBuilder::new(&command);
350 builder.env_clear();
351 for (k, v) in std::env::vars_os() {
352 let k_str = k.to_string_lossy();
354 if env_remove.iter().any(|r| *r == *k_str) {
355 continue;
356 }
357 builder.env(k, v);
358 }
359 for a in args {
360 builder.arg(a);
361 }
362 if let Some(p) = cwd {
363 builder.cwd(p);
364 }
365 for (k, v) in env {
367 builder.env(k, v);
368 }
369
370 debug!(command = %command, "spawning PTY agent");
371
372 let child = pair
373 .slave
374 .spawn_command(builder)
375 .map_err(|e| DriverError::SpawnFailed(std::io::Error::other(e.to_string())))?;
376
377 let reader = pair
378 .master
379 .try_clone_reader()
380 .map_err(|e| DriverError::Io(std::io::Error::other(e.to_string())))?;
381 let writer = pair
382 .master
383 .take_writer()
384 .map_err(|e| DriverError::Io(std::io::Error::other(e.to_string())))?;
385
386 let (input_tx, input_rx) = mpsc::channel::<Vec<u8>>(64);
387 let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(256);
388 let exited = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
389
390 spawn_reader_thread(reader, parser, event_tx.clone(), std::sync::Arc::clone(&exited));
391 spawn_writer_thread(writer, input_rx);
392 spawn_child_waiter(child, event_tx, std::sync::Arc::clone(&exited));
393
394 drop(pair.slave);
396
397 Ok(PtyDriver {
398 input_tx: Some(input_tx),
399 event_rx,
400 master: pair.master,
401 exited,
402 })
403 }
404}
405
406fn spawn_reader_thread<P: AgentParser>(
411 mut reader: Box<dyn std::io::Read + Send>,
412 mut parser: P,
413 tx: mpsc::Sender<AgentEvent>,
414 exited: std::sync::Arc<std::sync::atomic::AtomicBool>,
415) {
416 std::thread::Builder::new()
417 .name("cap-rs-pty-reader".into())
418 .spawn(move || {
419 let mut buf = [0u8; 8192];
420 loop {
421 match reader.read(&mut buf) {
422 Ok(0) => {
423 trace!("PTY reader: EOF");
424 break;
425 }
426 Ok(n) => {
427 let events = parser.on_bytes(&buf[..n]);
428 for ev in events {
429 if tx.blocking_send(ev).is_err() {
430 trace!("PTY reader: receiver dropped, exiting");
431 return;
432 }
433 }
434 }
435 Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
436 Err(e) => {
437 warn!(error = %e, "PTY reader: read error");
438 break;
439 }
440 }
441 }
442 for ev in parser.on_eof() {
443 let _ = tx.blocking_send(ev);
444 }
445 exited.store(true, std::sync::atomic::Ordering::Relaxed);
446 })
447 .expect("failed to spawn PTY reader thread");
448}
449
450fn spawn_writer_thread(
451 mut writer: Box<dyn std::io::Write + Send>,
452 mut rx: mpsc::Receiver<Vec<u8>>,
453) {
454 std::thread::Builder::new()
455 .name("cap-rs-pty-writer".into())
456 .spawn(move || {
457 while let Some(bytes) = rx.blocking_recv() {
458 if let Err(e) = writer.write_all(&bytes) {
459 warn!(error = %e, "PTY writer: write failed");
460 return;
461 }
462 if let Err(e) = writer.flush() {
463 warn!(error = %e, "PTY writer: flush failed");
464 return;
465 }
466 }
467 trace!("PTY writer: input channel closed, exiting");
468 })
469 .expect("failed to spawn PTY writer thread");
470}
471
472fn spawn_child_waiter(
473 mut child: Box<dyn portable_pty::Child + Send + Sync>,
474 event_tx: mpsc::Sender<AgentEvent>,
475 exited: std::sync::Arc<std::sync::atomic::AtomicBool>,
476) {
477 std::thread::Builder::new()
478 .name("cap-rs-pty-waiter".into())
479 .spawn(move || {
480 let _ = child.wait();
482 std::thread::sleep(Duration::from_millis(50));
484 exited.store(true, std::sync::atomic::Ordering::Relaxed);
485 drop(event_tx);
488 })
489 .expect("failed to spawn PTY child waiter thread");
490}