use tokio::process::Command;
use tokio::time::Duration;
use crate::agent::tools::ToolError;
/// Output of [`run_with_timeout`]: the child's stdout and stderr merged into one string,
/// plus its exit code.
#[derive(Debug)]
pub(crate) struct InterleavedOutput {
    /// stdout and stderr lines in the order they were read; may end with a notice saying
    /// how many bytes were dropped once the output cap was reached.
    pub merged: String,
    /// Exit code as filled in by `run_with_timeout`: the OS-reported code, or `-1` when
    /// there is none (on Unix, when the process was terminated by a signal).
    pub exit_code: i32,
}
/// Drop guard that `SIGKILL`s the process group identified by `pid` unless it has been
/// disarmed first.
#[cfg(unix)]
struct PgKillGuard {
    /// Child pid. The callers in this file spawn the child with `process_group(0)`, so this
    /// pid is also the id of the child's process group.
    pid: u32,
    /// While `true`, dropping the guard kills the process group.
    armed: bool,
}
#[cfg(unix)]
impl PgKillGuard {
fn new(pid: u32) -> Self {
Self { pid, armed: true }
}
fn disarm(&mut self) {
self.armed = false;
}
}
#[cfg(unix)]
impl Drop for PgKillGuard {
    /// Sends `SIGKILL` to the whole process group `pid` if the guard is still armed.
    ///
    /// The result of `kill` is ignored (e.g. `ESRCH` when the group is already gone): this
    /// is best-effort cleanup and `Drop` has no way to report failure.
    fn drop(&mut self) {
        if !self.armed {
            return;
        }
        // Refuse pids that would make `-pgid` a special target for kill(2): 0 means "our own
        // process group", 1 would become -1 ("every process we may signal"), and values
        // above `pid_t::MAX` would wrap when cast. None of these can be a real child's pid.
        if self.pid <= 1 || self.pid > libc::pid_t::MAX as u32 {
            return;
        }
        let pgid = self.pid as libc::pid_t;
        // SAFETY: `kill` has no memory-safety preconditions; it takes plain integers. A
        // negative pid addresses the process group `pgid`, which was validated above.
        unsafe {
            let _ = libc::kill(-pgid, libc::SIGKILL);
        }
    }
}
pub(crate) async fn run_with_timeout(
cmd: Command,
secs: u64,
) -> Result<InterleavedOutput, ToolError> {
use std::process::Stdio;
let mut cmd = cmd;
cmd.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
cmd.kill_on_drop(true);
#[cfg(unix)]
{
cmd.process_group(0);
}
let mut child = cmd
.spawn()
.map_err(|e| ToolError::Msg(format!("failed to spawn: {}", e)))?;
let pid = child.id();
#[cfg(not(unix))]
let _ = pid;
#[cfg(unix)]
let _pgguard = pid.map(PgKillGuard::new);
let stdout = child.stdout.take();
let stderr = child.stderr.take();
let drain = async move {
use tokio::io::AsyncBufReadExt;
let mut merged = String::new();
let mut so = stdout.map(tokio::io::BufReader::new);
let mut se = stderr.map(tokio::io::BufReader::new);
const DRAIN_CAP_BYTES: usize = 256 * 1024;
let mut overflow_bytes: usize = 0;
loop {
let has_so = so.is_some();
let has_se = se.is_some();
if !has_so && !has_se {
break;
}
let mut so_buf = String::new();
let mut se_buf = String::new();
let so_fut = async {
match so.as_mut() {
Some(r) => r.read_line(&mut so_buf).await.map(Some),
None => Ok::<_, std::io::Error>(None),
}
};
let se_fut = async {
match se.as_mut() {
Some(r) => r.read_line(&mut se_buf).await.map(Some),
None => Ok::<_, std::io::Error>(None),
}
};
tokio::select! {
biased;
r = so_fut, if has_so => match r {
Ok(Some(0)) | Ok(None) | Err(_) => { so = None; }
Ok(Some(n)) => {
if merged.len() < DRAIN_CAP_BYTES {
merged.push_str(&so_buf);
} else {
overflow_bytes = overflow_bytes.saturating_add(n);
}
},
},
r = se_fut, if has_se => match r {
Ok(Some(0)) | Ok(None) | Err(_) => { se = None; }
Ok(Some(n)) => {
if merged.len() < DRAIN_CAP_BYTES {
merged.push_str(&se_buf);
} else {
overflow_bytes = overflow_bytes.saturating_add(n);
}
},
},
}
}
if overflow_bytes > 0 {
if !merged.is_empty() && !merged.ends_with('\n') {
merged.push('\n');
}
merged.push_str(&format!(
"…[bash output exceeded cap; discarded {} additional bytes streamed after the {}-KiB cap]",
overflow_bytes,
DRAIN_CAP_BYTES / 1024,
));
}
merged
};
let wait = async {
let merged = drain.await;
let status = child.wait().await?;
Ok::<_, std::io::Error>((merged, status))
};
let outcome = tokio::time::timeout(Duration::from_secs(secs), wait).await;
match outcome {
Ok(Ok((merged, status))) => {
#[cfg(unix)]
{
let mut g = _pgguard;
if let Some(ref mut gg) = g {
gg.disarm();
}
}
Ok(InterleavedOutput {
merged,
exit_code: status.code().unwrap_or(-1),
})
}
Ok(Err(e)) => Err(ToolError::Msg(format!("wait failed: {}", e))),
Err(_) => {
#[cfg(unix)]
{
let mut g = _pgguard;
if let Some(ref mut gg) = g {
gg.disarm();
}
if let Some(pid) = pid {
unsafe {
let _ = libc::kill(-(pid as libc::pid_t), libc::SIGKILL);
}
}
}
let _ = pid;
Err(ToolError::Msg(format!("Command timed out after {}s", secs)))
}
}
}
/// Spawns a Tokio task that runs `cmd` in the background and streams its output into `store`.
///
/// Each stdout/stderr line is passed to `store.append(&id, ..)` as soon as it is complete.
/// Invalid UTF-8 is replaced with U+FFFD. Exactly one terminal status is then recorded
/// with `store.finish(&id, ..)`:
///
/// * `ShellStatus::Exited(code)` once both pipes reach EOF and the child is reaped. `code`
///   is `-1` when the OS reports none, e.g. death by signal on Unix.
/// * `ShellStatus::Failed(..)` in any of these cases:
///   * spawning fails;
///   * waiting fails;
///   * `timeout` (seconds) elapses first, in which case the child's process group is
///     killed on Unix;
///   * the task ends without reaching the normal finish (it unwinds, or the returned handle
///     is aborted).
///
/// On Unix the child leads its own process group. Unless the child was observed to exit,
/// that group is `SIGKILL`ed when the task ends.
pub(super) fn spawn_streaming_shell(
    cmd: Command,
    store: crate::agent::tools::bg_shell::BackgroundShellStore,
    id: String,
    timeout: Option<u64>,
) -> tokio::task::JoinHandle<()> {
    use crate::agent::tools::bg_shell::{BackgroundShellStore, ShellStatus};
    use std::process::Stdio;
    use tokio::io::{AsyncBufReadExt, BufReader};

    /// Records the shell's terminal status exactly once: explicitly through
    /// `FinishOnce::finish`, or as a failure from `Drop` if the task ends before that.
    struct FinishOnce {
        store: BackgroundShellStore,
        id: String,
        recorded: bool,
    }

    impl FinishOnce {
        /// Records `status` and disables the `Drop` fallback.
        fn finish(mut self, status: ShellStatus) {
            self.recorded = true;
            self.store.finish(&self.id, status);
        }
    }

    impl Drop for FinishOnce {
        fn drop(&mut self) {
            if !self.recorded {
                self.store.finish(
                    &self.id,
                    ShellStatus::Failed("drain task ended without recording an exit status".into()),
                );
            }
        }
    }

    tokio::spawn(async move {
        let backstop = FinishOnce {
            store: store.clone(),
            id: id.clone(),
            recorded: false,
        };

        let mut cmd = cmd;
        cmd.stdin(Stdio::null())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .kill_on_drop(true);
        // Make the child a process-group leader so everything it starts can be killed at once.
        #[cfg(unix)]
        cmd.process_group(0);

        let mut child = match cmd.spawn() {
            Ok(c) => c,
            Err(e) => {
                backstop.finish(ShellStatus::Failed(format!("failed to spawn: {e}")));
                return;
            }
        };
        // Stays armed unless the child is observed to exit; dropping it armed kills the group.
        #[cfg(unix)]
        let mut pgguard = child.id().map(PgKillGuard::new);
        let stdout = child.stdout.take();
        let stderr = child.stderr.take();

        let drain = async move {
            let mut so = stdout.map(BufReader::new);
            let mut se = stderr.map(BufReader::new);
            // These buffers must outlive each loop iteration. When `select!` drops the
            // losing `read_until` future, the bytes it already consumed stay in its buffer,
            // and the next call completes the line.
            let mut so_buf: Vec<u8> = Vec::new();
            let mut se_buf: Vec<u8> = Vec::new();

            while so.is_some() || se.is_some() {
                let (has_so, has_se) = (so.is_some(), se.is_some());
                let so_fut = async {
                    match so.as_mut() {
                        Some(r) => r.read_until(b'\n', &mut so_buf).await.map(Some),
                        None => Ok::<_, std::io::Error>(None),
                    }
                };
                let se_fut = async {
                    match se.as_mut() {
                        Some(r) => r.read_until(b'\n', &mut se_buf).await.map(Some),
                        None => Ok::<_, std::io::Error>(None),
                    }
                };
                // `Ok(Some(n))` with `n > 0`: a line is buffered and more may follow.
                // `Ok(Some(0))` is EOF; a partial line from a cancelled read may remain.
                // An error ends that stream. In every case, flush what has been buffered.
                tokio::select! {
                    biased;
                    r = so_fut, if has_so => {
                        if !so_buf.is_empty() {
                            let line = String::from_utf8_lossy(&so_buf).into_owned();
                            store.append(&id, &line);
                            so_buf.clear();
                        }
                        if !matches!(r, Ok(Some(n)) if n > 0) {
                            so = None;
                        }
                    },
                    r = se_fut, if has_se => {
                        if !se_buf.is_empty() {
                            let line = String::from_utf8_lossy(&se_buf).into_owned();
                            store.append(&id, &line);
                            se_buf.clear();
                        }
                        if !matches!(r, Ok(Some(n)) if n > 0) {
                            se = None;
                        }
                    },
                }
            }
        };

        let wait = async {
            drain.await;
            child.wait().await
        };
        let status = match timeout {
            Some(secs) => match tokio::time::timeout(Duration::from_secs(secs), wait).await {
                Ok(Ok(st)) => ShellStatus::Exited(st.code().unwrap_or(-1)),
                Ok(Err(e)) => ShellStatus::Failed(e.to_string()),
                Err(_) => {
                    // Drop the still-armed guard now, so the whole process group is killed
                    // before the failure is published.
                    #[cfg(unix)]
                    drop(pgguard.take());
                    ShellStatus::Failed(format!("auto-killed after {secs}s timeout"))
                }
            },
            None => match wait.await {
                Ok(st) => ShellStatus::Exited(st.code().unwrap_or(-1)),
                Err(e) => ShellStatus::Failed(e.to_string()),
            },
        };
        // Only a child that was actually reaped leaves its group alone. On failure the guard
        // stays armed and kills the group when this task ends.
        #[cfg(unix)]
        if matches!(status, ShellStatus::Exited(_)) {
            if let Some(guard) = pgguard.as_mut() {
                guard.disarm();
            }
        }
        backstop.finish(status);
    })
}