Skip to main content

harness/
process.rs

1use std::path::Path;
2use std::process::Stdio;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5
6use tokio::io::{AsyncBufReadExt, BufReader};
7use tokio::process::Command;
8use tokio_util::sync::CancellationToken;
9
10use crate::config::TaskConfig;
11use crate::error::{Error, Result};
12use crate::event::Event;
13use crate::runner::{AgentRunner, EventStream};
14
/// Upper bound (64 KiB) on how many bytes of stderr text are retained for
/// error reporting; anything beyond this is not kept.
const MAX_STDERR_BYTES: usize = 1 << 16;
17
/// Guard that kills a child's process tree exactly once, either when
/// [`ChildGuard::kill`] is called or when the guard is dropped.
///
/// On Unix the child is spawned as the leader of its own process group, so
/// `pid` doubles as the group ID: the guard sends SIGTERM to the group and,
/// if that succeeds, schedules a SIGKILL from a background thread 2s later
/// (sent regardless of whether the group already exited). On Windows it runs
/// `taskkill /T /F`; on other platforms it only logs a warning.
///
/// `spawn_and_stream` moves the guard into its reader task, so the process is
/// killed whenever that task ends — including when the caller drops the stream.
#[derive(Debug)]
pub(crate) struct ChildGuard {
    /// PID of the spawned child (and its process-group ID on Unix).
    pid: u32,
    /// Set by the first `kill` call so signals are only sent once.
    killed: AtomicBool,
}
26
27impl ChildGuard {
28    fn new(pid: u32) -> Self {
29        Self {
30            pid,
31            killed: AtomicBool::new(false),
32        }
33    }
34
35    /// Actively kill the process group (SIGTERM, then SIGKILL after 2s).
36    ///
37    /// Safe to call multiple times — only the first call sends signals.
38    #[cfg(unix)]
39    pub(crate) fn kill(&self) {
40        // Ensure we only send signals once.
41        if self.killed.swap(true, Ordering::SeqCst) {
42            return;
43        }
44
45        use nix::sys::signal::{killpg, Signal};
46        use nix::unistd::Pid;
47
48        let pgid = Pid::from_raw(self.pid as i32);
49        if let Err(e) = killpg(pgid, Signal::SIGTERM) {
50            tracing::debug!("SIGTERM to pgid {} failed: {e}", self.pid);
51            return; // Process already gone, no need for SIGKILL.
52        }
53
54        let pid = self.pid;
55        std::thread::spawn(move || {
56            std::thread::sleep(std::time::Duration::from_secs(2));
57            let pgid = Pid::from_raw(pid as i32);
58            if let Err(e) = killpg(pgid, Signal::SIGKILL) {
59                tracing::debug!("SIGKILL to pgid {} failed: {e}", pid);
60            }
61        });
62    }
63
64    #[cfg(windows)]
65    pub(crate) fn kill(&self) {
66        if self.killed.swap(true, Ordering::SeqCst) {
67            return;
68        }
69        if let Err(e) = std::process::Command::new("taskkill")
70            .args(["/PID", &self.pid.to_string(), "/T", "/F"])
71            .output()
72        {
73            tracing::debug!("taskkill for pid {} failed: {e}", self.pid);
74        }
75    }
76
77    #[cfg(not(any(unix, windows)))]
78    pub(crate) fn kill(&self) {
79        if self.killed.swap(true, Ordering::SeqCst) {
80            return;
81        }
82        tracing::warn!("process cleanup not supported on this platform (pid={})", self.pid);
83    }
84}
85
86impl Drop for ChildGuard {
87    fn drop(&mut self) {
88        self.kill();
89    }
90}
91
92/// A handle that bundles an `EventStream` with a `CancellationToken`.
93///
94/// Cancelling the token gracefully stops the stream and kills the subprocess.
95pub struct StreamHandle {
96    /// The unified event stream.
97    pub stream: EventStream,
98    /// Cancel this to stop the agent subprocess.
99    pub cancel_token: CancellationToken,
100}
101
102/// Spawns an agent subprocess and returns a `StreamHandle` containing an
103/// `EventStream` and a `CancellationToken`.
104///
105/// This is the shared scaffolding used by every adapter — the only thing that
106/// differs per-agent is arg construction and line parsing.
107///
108/// The parser function returns a `Vec` so that a single JSON line can produce
109/// multiple events (e.g., an assistant message with both text + tool_use blocks).
110///
111/// If `cancel_token` is `None`, a new token is created internally.
112pub async fn spawn_and_stream<F>(
113    runner: &dyn AgentRunner,
114    config: &TaskConfig,
115    parse_line: F,
116    cancel_token: Option<CancellationToken>,
117) -> Result<StreamHandle>
118where
119    F: Fn(&str) -> Vec<Result<Event>> + Send + Sync + 'static,
120{
121    let binary = runner.binary_path(config)?;
122    let args = runner.build_args(config);
123    let env_vars = runner.build_env(config);
124
125    let cwd = config
126        .cwd
127        .clone()
128        .unwrap_or_else(|| std::env::current_dir().unwrap_or_default());
129
130    validate_cwd(&cwd)?;
131
132    tracing::debug!(
133        agent = runner.name(),
134        binary = %binary.display(),
135        args = ?args,
136        cwd = %cwd.display(),
137        "spawning agent process"
138    );
139
140    let mut cmd = Command::new(&binary);
141    cmd.args(&args)
142        .current_dir(&cwd)
143        .stdin(Stdio::null())
144        .stdout(Stdio::piped())
145        .stderr(Stdio::piped());
146
147    // On Unix, create a new process group so we can kill the entire tree.
148    #[cfg(unix)]
149    cmd.process_group(0);
150
151    for (k, v) in &env_vars {
152        cmd.env(k, v);
153    }
154
155    // Forward any user-supplied env vars.
156    for (k, v) in &config.env {
157        cmd.env(k, v);
158    }
159
160    let mut child = cmd.spawn().map_err(Error::SpawnFailed)?;
161
162    let child_pid = child
163        .id()
164        .ok_or_else(|| Error::Other("failed to get child process ID".into()))?;
165    let guard = Arc::new(ChildGuard::new(child_pid));
166
167    let stdout = child
168        .stdout
169        .take()
170        .ok_or_else(|| Error::Other("failed to capture stdout".into()))?;
171
172    let stderr = child
173        .stderr
174        .take()
175        .ok_or_else(|| Error::Other("failed to capture stderr".into()))?;
176
177    // Spawn a task to collect stderr for error reporting (capped at MAX_STDERR_BYTES).
178    let stderr_handle = tokio::spawn(async move {
179        let reader = BufReader::new(stderr);
180        let mut lines = reader.lines();
181        let mut buf = String::new();
182        while let Ok(Some(line)) = lines.next_line().await {
183            if buf.len() >= MAX_STDERR_BYTES {
184                break;
185            }
186            if !buf.is_empty() {
187                buf.push('\n');
188            }
189            let remaining = MAX_STDERR_BYTES - buf.len();
190            if line.len() > remaining {
191                buf.push_str(&line[..remaining]);
192                break;
193            }
194            buf.push_str(&line);
195        }
196        buf
197    });
198
199    // Spawn a task to wait for exit status.
200    let wait_handle = tokio::spawn(async move { child.wait().await });
201
202    let mut reader = BufReader::new(stdout).lines();
203
204    // Create or use the provided cancellation token.
205    let token = cancel_token.unwrap_or_default();
206    let token_for_task = token.clone();
207
208    // Use an mpsc channel so a spawned task can select! between line reads
209    // and cancellation — this ensures cancellation is responsive even when
210    // the subprocess is blocking (e.g. sleeping).
211    let (tx, rx) = tokio::sync::mpsc::channel::<Result<Event>>(256);
212
213    tokio::spawn(async move {
214        // Keep guard alive for the duration of this task.
215        let _guard = guard;
216
217        loop {
218            tokio::select! {
219                _ = token_for_task.cancelled() => {
220                    // Cancel requested — kill the subprocess and stop.
221                    _guard.kill();
222                    break;
223                }
224                line_result = reader.next_line() => {
225                    match line_result {
226                        Ok(Some(line)) => {
227                            if line.trim().is_empty() {
228                                continue;
229                            }
230                            let events = parse_line(&line);
231                            for result in events {
232                                let stamped = result.map(|e| e.stamp());
233                                if tx.send(stamped).await.is_err() {
234                                    return; // receiver dropped
235                                }
236                            }
237                        }
238                        Ok(None) => break, // EOF
239                        Err(e) => {
240                            let _ = tx.send(Err(Error::Io(e))).await;
241                            break;
242                        }
243                    }
244                }
245            }
246        }
247
248        // If we were cancelled, don't bother waiting for exit status.
249        if token_for_task.is_cancelled() {
250            return;
251        }
252
253        // After stdout closes, check exit status.
254        match wait_handle.await {
255            Ok(Ok(status)) if !status.success() => {
256                let stderr_text = stderr_handle.await.unwrap_or_default();
257                let code = status.code().unwrap_or(-1);
258                let _ = tx
259                    .send(Err(Error::ProcessFailed {
260                        code,
261                        stderr: stderr_text,
262                    }))
263                    .await;
264            }
265            Ok(Err(e)) => {
266                let _ = tx.send(Err(Error::Io(e))).await;
267            }
268            Err(e) => {
269                let _ = tx
270                    .send(Err(Error::Other(format!("join error: {e}"))))
271                    .await;
272            }
273            _ => {} // success — adapter should have emitted Result event
274        }
275    });
276
277    let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
278
279    Ok(StreamHandle {
280        stream: Box::pin(stream),
281        cancel_token: token,
282    })
283}
284
285fn validate_cwd(cwd: &Path) -> Result<()> {
286    if !cwd.exists() {
287        return Err(Error::InvalidWorkDir(cwd.to_path_buf()));
288    }
289    if !cwd.is_dir() {
290        return Err(Error::Other(format!(
291            "working directory is not a directory: {}",
292            cwd.display()
293        )));
294    }
295    Ok(())
296}