use anyhow::{Context, Result};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use portable_pty::{native_pty_system, CommandBuilder, PtySize};
use std::io::{Read, Write};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use crate::config::ExfilDetectionSettings;
use crate::ipc::client::IpcClient;
use crate::ipc::protocol::WrapState;
use crate::wrap::analyzer::Analyzer;
use crate::wrap::exfil_detector::ExfilDetector;
/// Settings describing what [`PtyRunner`] should launch and how.
pub struct PtyRunnerConfig {
    /// Program to execute inside the pseudo-terminal.
    pub command: String,
    /// Arguments passed to `command`, in order.
    pub args: Vec<String>,
    /// Session identifier handed to the IPC client when the runner starts.
    pub id: String,
    /// PTY height (rows) used when the real terminal size cannot be queried.
    pub rows: u16,
    /// PTY width (columns) used when the real terminal size cannot be queried.
    pub cols: u16,
    /// Configuration forwarded to the exfiltration detector.
    pub exfil_detection: ExfilDetectionSettings,
}

impl Default for PtyRunnerConfig {
    /// An empty command with no arguments, a freshly generated random (v4)
    /// UUID as the id, a classic 24x80 fallback size and default exfil
    /// detection settings.
    fn default() -> Self {
        let session_id = uuid::Uuid::new_v4().to_string();
        Self {
            command: String::default(),
            args: Vec::default(),
            id: session_id,
            rows: 24,
            cols: 80,
            exfil_detection: Default::default(),
        }
    }
}
pub struct PtyRunner {
config: PtyRunnerConfig,
}
impl PtyRunner {
pub fn new(config: PtyRunnerConfig) -> Self {
Self { config }
}
pub fn run(self) -> Result<i32> {
let (rows, cols) = get_terminal_size().unwrap_or((self.config.rows, self.config.cols));
let pty_system = native_pty_system();
let pair = pty_system
.openpty(PtySize {
rows,
cols,
pixel_width: 0,
pixel_height: 0,
})
.context("Failed to open PTY")?;
let mut cmd = CommandBuilder::new(&self.config.command);
cmd.args(&self.config.args);
if let Ok(cwd) = std::env::current_dir() {
cmd.cwd(cwd);
}
let mut child = pair
.slave
.spawn_command(cmd)
.context("Failed to spawn command")?;
let child_pid = child.process_id().unwrap_or(0);
tracing::debug!("Spawned {} with PID {}", self.config.command, child_pid);
let analyzer = Arc::new(parking_lot::Mutex::new(Analyzer::new(child_pid)));
let exfil_detector = Arc::new(ExfilDetector::new(&self.config.exfil_detection, child_pid));
let running = Arc::new(AtomicBool::new(true));
let mut master_reader = pair
.master
.try_clone_reader()
.context("Failed to clone PTY reader")?;
let master_writer = pair
.master
.take_writer()
.context("Failed to take PTY writer")?;
let master_writer_shared: Arc<parking_lot::Mutex<Box<dyn Write + Send>>> =
Arc::new(parking_lot::Mutex::new(master_writer));
let team_name = analyzer.lock().team_name().cloned();
let team_member_name = analyzer.lock().team_member_name().cloned();
let is_team_lead = analyzer.lock().is_team_lead();
let ipc_client = IpcClient::start(
self.config.id.clone(),
child_pid,
team_name,
team_member_name,
is_team_lead,
running.clone(),
master_writer_shared.clone(),
analyzer.clone(),
);
let analyzer_out = analyzer.clone();
let exfil_detector_out = exfil_detector.clone();
let running_out = running.clone();
let output_thread = thread::spawn(move || {
let mut stdout = std::io::stdout();
let mut buf = [0u8; 4096];
while running_out.load(Ordering::Relaxed) {
match master_reader.read(&mut buf) {
Ok(0) => break, Ok(n) => {
if stdout.write_all(&buf[..n]).is_err() {
break;
}
let _ = stdout.flush();
if let Ok(s) = std::str::from_utf8(&buf[..n]) {
analyzer_out.lock().process_output(s);
exfil_detector_out.check_output(s);
}
}
Err(e) => {
if e.kind() != std::io::ErrorKind::WouldBlock {
tracing::debug!("PTY read error: {}", e);
break;
}
}
}
}
});
let analyzer_in = analyzer.clone();
let running_in = running.clone();
let writer_for_input = master_writer_shared;
let input_thread = thread::spawn(move || {
let stdin = std::io::stdin();
let mut stdin = stdin.lock();
let mut buf = [0u8; 1024];
while running_in.load(Ordering::Relaxed) {
match stdin.read(&mut buf) {
Ok(0) => break, Ok(n) => {
{
let mut writer = writer_for_input.lock();
if writer.write_all(&buf[..n]).is_err() {
break;
}
let _ = writer.flush();
}
if let Ok(s) = std::str::from_utf8(&buf[..n]) {
analyzer_in.lock().process_input(s);
}
}
Err(e) => {
if e.kind() != std::io::ErrorKind::WouldBlock {
tracing::debug!("stdin read error: {}", e);
break;
}
}
}
}
});
let analyzer_state = analyzer.clone();
let running_state = running.clone();
let state_thread = thread::spawn(move || {
let mut last_state: Option<WrapState> = None;
while running_state.load(Ordering::Relaxed) {
thread::sleep(Duration::from_millis(100));
let state = analyzer_state.lock().get_state();
let should_send = match &last_state {
None => true,
Some(prev) => !states_equal(prev, &state),
};
if should_send {
ipc_client.send_state(state.clone());
last_state = Some(state);
}
}
});
let running_resize = running.clone();
let pty_master = pair.master;
let resize_thread = thread::spawn(move || {
let mut last_size: Option<(u16, u16)> = get_terminal_size();
while running_resize.load(Ordering::Relaxed) {
thread::sleep(Duration::from_millis(100));
let current_size = get_terminal_size();
if current_size != last_size {
if let Some((rows, cols)) = current_size {
let _ = pty_master.resize(PtySize {
rows,
cols,
pixel_width: 0,
pixel_height: 0,
});
}
last_size = current_size;
}
}
});
let exit_status = child.wait().context("Failed to wait for child")?;
running.store(false, Ordering::Relaxed);
join_thread_with_timeout(output_thread, Duration::from_secs(1));
join_thread_with_timeout(input_thread, Duration::from_secs(1));
join_thread_with_timeout(state_thread, Duration::from_secs(1));
join_thread_with_timeout(resize_thread, Duration::from_secs(1));
Ok(exit_status.exit_code() as i32)
}
}
/// Reports whether two [`WrapState`] snapshots agree on status, approval
/// type, details, choices, multi-select flag, cursor position, pid and pane id.
///
/// Only those fields are compared. NOTE(review): if `WrapState` has other
/// fields, they are intentionally ignored here. If it derives `PartialEq`,
/// confirm whether `a == b` would be equivalent.
fn states_equal(a: &WrapState, b: &WrapState) -> bool {
    let lhs = (
        &a.status,
        &a.approval_type,
        &a.details,
        &a.choices,
        &a.multi_select,
        &a.cursor_position,
        &a.pid,
        &a.pane_id,
    );
    let rhs = (
        &b.status,
        &b.approval_type,
        &b.details,
        &b.choices,
        &b.multi_select,
        &b.cursor_position,
        &b.pid,
        &b.pane_id,
    );
    // Tuple equality compares element by element and stops at the first mismatch.
    lhs == rhs
}
/// Waits up to `timeout` for the thread behind `handle` to finish, polling
/// every 10 ms.
///
/// If the thread finishes in time it is joined, and its result, including
/// a panic payload, is discarded. Otherwise a debug message is logged and
/// the handle is dropped, which detaches the thread.
fn join_thread_with_timeout<T>(handle: JoinHandle<T>, timeout: Duration) {
    let waited_since = Instant::now();
    while !handle.is_finished() {
        if waited_since.elapsed() >= timeout {
            tracing::debug!("Thread join timed out, abandoning thread");
            return;
        }
        thread::sleep(Duration::from_millis(10));
    }
    let _ = handle.join();
}
/// Queries the window size of the terminal attached to stdout.
///
/// Returns `Some((rows, cols))`, or `None` if the `TIOCGWINSZ` ioctl fails
/// (for example when stdout is not a terminal) or reports a zero dimension.
fn get_terminal_size() -> Option<(u16, u16)> {
    use nix::libc;
    // SAFETY: `winsize` is a plain C struct of integer fields, for which the
    // all-zero bit pattern is a valid value.
    let mut winsz: libc::winsize = unsafe { std::mem::zeroed() };
    // SAFETY: TIOCGWINSZ writes one `winsize` through the pointer, which
    // refers to a live, properly aligned local of exactly that type.
    let rc = unsafe { libc::ioctl(libc::STDOUT_FILENO, libc::TIOCGWINSZ, &mut winsz) };
    if rc != 0 || winsz.ws_row == 0 || winsz.ws_col == 0 {
        return None;
    }
    Some((winsz.ws_row, winsz.ws_col))
}
pub fn forward_signal_to_child(child_pid: u32, sig: Signal) -> Result<()> {
if child_pid > 0 {
signal::kill(Pid::from_raw(child_pid as i32), sig).context("Failed to forward signal")?;
}
Ok(())
}
/// Splits a command line into a program name and its arguments.
///
/// Words are separated by whitespace. POSIX-shell-style quoting lets a single
/// word contain spaces:
///
/// * `'...'`: everything up to the next `'` is taken literally.
/// * `"..."`: taken literally, except that `\"` and `\\` produce `"` and `\`.
///   Any other backslash inside double quotes is kept as-is.
/// * `\c` outside quotes produces `c` literally (e.g. `a\ b` is one word).
/// * Quoted pieces adjacent to other text join into the same word, and an
///   empty quoted string (`''` or `""`) yields an empty argument.
/// * An unterminated quote extends to the end of the input.
///
/// No variable expansion, globbing or other shell processing is performed.
/// Input without quotes or backslashes is split exactly as by
/// `str::split_whitespace`.
///
/// Returns `(String::new(), Vec::new())` when the input contains no words.
pub fn parse_command(cmd_str: &str) -> (String, Vec<String>) {
    let mut words = split_command_words(cmd_str).into_iter();
    match words.next() {
        Some(command) => (command, words.collect()),
        None => (String::new(), Vec::new()),
    }
}

/// Tokenizes `input` into words using the quoting rules documented on [`parse_command`].
fn split_command_words(input: &str) -> Vec<String> {
    let mut words = Vec::new();
    let mut current = String::new();
    // Set once a word has started, even if it is still empty (e.g. after
    // `''`), so that empty quoted arguments are preserved.
    let mut in_word = false;
    let mut chars = input.chars();
    while let Some(ch) = chars.next() {
        match ch {
            ws if ws.is_whitespace() => {
                if in_word {
                    words.push(std::mem::take(&mut current));
                    in_word = false;
                }
            }
            '\'' => {
                in_word = true;
                for quoted in chars.by_ref() {
                    if quoted == '\'' {
                        break;
                    }
                    current.push(quoted);
                }
            }
            '"' => {
                in_word = true;
                while let Some(quoted) = chars.next() {
                    match quoted {
                        '"' => break,
                        '\\' => match chars.next() {
                            Some(escaped @ ('"' | '\\')) => current.push(escaped),
                            Some(other) => {
                                current.push('\\');
                                current.push(other);
                            }
                            None => current.push('\\'),
                        },
                        other => current.push(other),
                    }
                }
            }
            '\\' => {
                in_word = true;
                // A trailing lone backslash is kept literally.
                current.push(chars.next().unwrap_or('\\'));
            }
            other => {
                in_word = true;
                current.push(other);
            }
        }
    }
    if in_word {
        words.push(current);
    }
    words
}
/// Returns an identifier for the pane this process runs in.
///
/// Uses `$TMUX_PANE` with its leading `%` characters stripped (so `%3`
/// becomes `3`). If the variable is unset, not valid Unicode, or empty after
/// stripping, a fresh random UUID is returned instead, so the result is
/// never empty.
pub fn get_pane_id() -> String {
    std::env::var("TMUX_PANE")
        .ok()
        .map(|pane| pane.trim_start_matches('%').to_string())
        // An empty or bare-`%` TMUX_PANE would otherwise yield an empty id.
        .filter(|pane| !pane.is_empty())
        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    // A lone program name parses to that name with no arguments.
    #[test]
    fn test_parse_command_simple() {
        let parsed = parse_command("claude");
        assert_eq!(parsed.0, "claude");
        assert!(parsed.1.is_empty());
    }

    // Words after the program become the arguments, in order.
    #[test]
    fn test_parse_command_with_args() {
        let (program, arguments) = parse_command("claude --debug --config test.toml");
        assert_eq!(program, "claude");
        assert_eq!(arguments, ["--debug", "--config", "test.toml"]);
    }

    // Empty input yields an empty program and no arguments.
    #[test]
    fn test_parse_command_empty() {
        let (program, arguments) = parse_command("");
        assert!(program.is_empty() && arguments.is_empty());
    }

    // Whether or not TMUX_PANE is set, some non-empty id comes back.
    #[test]
    fn test_get_pane_id_fallback() {
        assert!(!get_pane_id().is_empty());
    }
}