use std::os::fd::{AsFd, AsRawFd, OwnedFd};
use std::sync::LazyLock;
use std::time::{Duration, Instant};
use anyhow::{Context, Result, bail};
use nix::poll::{PollFd, PollFlags, poll};
use nix::pty::{OpenptyResult, openpty};
use nix::sys::signal::Signal;
use nix::unistd::{ForkResult, Pid, close, dup2, execvp, fork, setsid};
use regex::Regex;
/// Matches terminal escape sequences to be removed from debugger output:
/// ECMA-48 CSI sequences (`ESC [` + parameter bytes `0x30-0x3F` + intermediate
/// bytes `0x20-0x2F` + a final byte `0x40-0x7E`, which also covers private-mode
/// forms such as `ESC[?2004h`) and OSC sequences terminated by BEL or `ESC \`.
static ANSI_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"\x1b\[[0-?]*[ -/]*[@-~]|\x1b\][^\x07\x1b]*(?:\x07|\x1b\\)")
        .expect("ANSI_RE is a constant, valid regex")
});
/// A debugger child process driven through a pseudo-terminal.
#[derive(Debug)]
pub struct DebuggerProcess {
    /// Parent side of the pty; the child's stdin/stdout/stderr are the slave side.
    master: OwnedFd,
    /// Pid of the forked debugger process.
    child_pid: Pid,
    /// Regex recognising the debugger's prompt in (ANSI-stripped) output.
    prompt_re: Regex,
}
impl DebuggerProcess {
pub fn spawn(
bin: &str,
args: &[String],
env_extra: &[(String, String)],
prompt_pattern: &str,
) -> Result<Self> {
let OpenptyResult { master, slave } = openpty(None, None)?;
let fork_result = unsafe { fork() }?;
match fork_result {
ForkResult::Child => {
drop(master);
setsid().ok();
let slave_fd = slave.as_raw_fd();
dup2(slave_fd, 0).ok();
dup2(slave_fd, 1).ok();
dup2(slave_fd, 2).ok();
if slave_fd > 2 {
close(slave_fd).ok();
}
unsafe {
for (k, v) in env_extra {
std::env::set_var(k, v);
}
std::env::set_var("TERM", "dumb");
}
let c_bin =
std::ffi::CString::new(bin).unwrap_or_else(|_| std::process::exit(127));
let mut c_args = vec![c_bin.clone()];
for a in args {
c_args.push(
std::ffi::CString::new(a.as_str())
.unwrap_or_else(|_| std::process::exit(127)),
);
}
execvp(&c_bin, &c_args).ok();
std::process::exit(127);
}
ForkResult::Parent { child } => {
drop(slave);
let prompt_re = Regex::new(prompt_pattern)
.context("invalid prompt pattern")?;
Ok(Self {
master,
child_pid: child,
prompt_re,
})
}
}
}
fn write_master(&self, data: &[u8]) -> Result<()> {
let fd = self.master.as_raw_fd();
let mut written = 0;
while written < data.len() {
match nix::unistd::write(unsafe { std::os::fd::BorrowedFd::borrow_raw(fd) }, &data[written..]) {
Ok(n) => written += n,
Err(nix::errno::Errno::EINTR) => continue,
Err(e) => return Err(e.into()),
}
}
Ok(())
}
fn read_master(&self, buf: &mut [u8]) -> usize {
nix::unistd::read(self.master.as_raw_fd(), buf).unwrap_or(0)
}
pub fn wait_for_prompt(&self, timeout: Duration) -> Result<String> {
self.read_until_prompt(timeout)
}
pub fn send_and_wait(&self, cmd: &str, timeout: Duration) -> Result<String> {
self.write_master(format!("{cmd}\n").as_bytes())?;
let raw = self.read_until_prompt(timeout)?;
let clean = strip_ansi(&raw);
let no_prompts = self.prompt_re.replace_all(&clean, "");
let lines: Vec<&str> = no_prompts.lines().collect();
let start = if !lines.is_empty() && lines[0].contains(cmd.trim()) {
1
} else {
0
};
let mut end = lines.len();
while end > start && lines[end - 1].trim().is_empty() {
end -= 1;
}
let output = lines[start..end].join("\n");
Ok(output.trim().to_string())
}
pub fn is_alive(&self) -> bool {
nix::sys::wait::waitpid(self.child_pid, Some(nix::sys::wait::WaitPidFlag::WNOHANG))
.is_ok_and(|s| matches!(s, nix::sys::wait::WaitStatus::StillAlive))
}
pub fn quit(&self, quit_cmd: &str) {
if self.is_alive() {
let _ = self.write_master(format!("{quit_cmd}\n").as_bytes());
std::thread::sleep(Duration::from_millis(500));
if self.is_alive() {
let _ = nix::sys::signal::kill(self.child_pid, Signal::SIGKILL);
}
}
}
fn read_until_prompt(&self, timeout: Duration) -> Result<String> {
let mut buf = [0u8; 4096];
let mut accumulated = String::new();
let start = Instant::now();
loop {
let remaining = timeout.saturating_sub(start.elapsed());
if remaining.is_zero() {
bail!("timeout waiting for prompt");
}
let fd = PollFd::new(self.master.as_fd(), PollFlags::POLLIN);
let ms = remaining.as_millis().min(u16::MAX as u128) as u16;
let n = poll(&mut [fd], ms)?;
if n == 0 {
if self.prompt_re.is_match(&strip_ansi(&accumulated)) {
break;
}
continue;
}
let bytes_read = self.read_master(&mut buf);
if bytes_read == 0 {
break;
}
accumulated.push_str(&String::from_utf8_lossy(&buf[..bytes_read]));
let cleaned = strip_ansi(&accumulated);
if self.prompt_re.is_match(&cleaned) {
std::thread::sleep(Duration::from_millis(20));
let extra_fd = PollFd::new(self.master.as_fd(), PollFlags::POLLIN);
if poll(&mut [extra_fd], 30u16).unwrap_or(0) > 0 {
let extra = self.read_master(&mut buf);
if extra > 0 {
accumulated.push_str(&String::from_utf8_lossy(&buf[..extra]));
}
}
break;
}
}
Ok(accumulated)
}
}
/// Returns a copy of `s` with every match of `ANSI_RE` removed.
///
/// Strings without an ESC byte are copied as-is, without running the regex.
fn strip_ansi(s: &str) -> String {
    let has_escape = s.contains('\x1b');
    if has_escape {
        ANSI_RE.replace_all(s, "").into_owned()
    } else {
        s.to_owned()
    }
}
impl Drop for DebuggerProcess {
fn drop(&mut self) {
let _ = nix::sys::signal::kill(self.child_pid, Signal::SIGTERM);
}
}