Skip to main content

ito_core/harness/
streaming_cli.rs

1use super::types::{Harness, HarnessName, HarnessRunConfig, HarnessRunResult};
2use miette::{Result, miette};
3use std::io::Write;
4use std::process::{Command, Stdio};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::thread;
8use std::time::{Duration, Instant};
9
/// Inactivity window (15 minutes) applied to a CLI harness run when
/// `HarnessRunConfig::inactivity_timeout` is `None`.
pub const DEFAULT_INACTIVITY_TIMEOUT: Duration = Duration::from_secs(900);
12
13/// A CLI-based harness that spawns a binary with streaming I/O.
14///
15/// Implementors describe *which* binary to run and *how* to build its argument
16/// list. The blanket [`Harness`] impl handles spawning, streaming, inactivity
17/// monitoring, and result collection — so individual harnesses only need to
18/// provide a few declarative methods.
19///
20/// # Examples
21///
22/// ```
23/// use ito_core::harness::{Harness, HarnessName, HarnessRunConfig};
24/// use ito_core::harness::streaming_cli::CliHarness;
25///
26/// #[derive(Debug)]
27/// struct MyHarness;
28///
29/// impl CliHarness for MyHarness {
30///     fn harness_name(&self) -> HarnessName { HarnessName::Codex }
31///     fn binary(&self) -> &str { "codex" }
32///     fn build_args(&self, config: &HarnessRunConfig) -> Vec<String> {
33///         vec!["exec".into(), config.prompt.clone()]
34///     }
35/// }
36///
37/// let h = MyHarness;
38/// assert_eq!(h.harness_name(), HarnessName::Codex);
39/// ```
40pub trait CliHarness: std::fmt::Debug {
41    /// The harness identity (e.g. [`HarnessName::Claude`]).
42    fn harness_name(&self) -> HarnessName;
43
44    /// The CLI binary to spawn (e.g. `"claude"`, `"codex"`).
45    fn binary(&self) -> &str;
46
47    /// Build the full argument list for a single invocation.
48    ///
49    /// Called once per `Harness::run`. The returned args are passed directly
50    /// to the binary — the trait handles spawning and streaming.
51    fn build_args(&self, config: &HarnessRunConfig) -> Vec<String>;
52}
53
54/// Blanket impl: every [`CliHarness`] is automatically a [`Harness`].
55impl<T: CliHarness> Harness for T {
56    fn name(&self) -> HarnessName {
57        self.harness_name()
58    }
59
60    fn run(&mut self, config: &HarnessRunConfig) -> Result<HarnessRunResult> {
61        let args = self.build_args(config);
62        run_streaming_cli(self.binary(), &args, config)
63    }
64
65    fn stop(&mut self) {
66        // No-op: `run` is synchronous.
67    }
68
69    fn streams_output(&self) -> bool {
70        true
71    }
72}
73
/// Selects the terminal stream that a child's pipe is mirrored onto.
enum StreamTarget {
    /// Mirror onto this process's standard output.
    Stdout,
    /// Mirror onto this process's standard error.
    Stderr,
}
81
82/// Spawns a CLI binary with streaming stdout/stderr and an inactivity monitor.
83///
84/// All harnesses delegate to this function so they share consistent streaming
85/// behaviour: output is forwarded to the terminal in real time, an inactivity
86/// timer kills the process when it stalls, and incomplete UTF-8 sequences at
87/// chunk boundaries are handled correctly.
88fn run_streaming_cli(
89    binary: &str,
90    args: &[String],
91    config: &HarnessRunConfig,
92) -> Result<HarnessRunResult> {
93    let mut cmd = Command::new(binary);
94    cmd.args(args);
95    cmd.current_dir(&config.cwd);
96    cmd.envs(&config.env);
97    cmd.stdout(Stdio::piped());
98    cmd.stderr(Stdio::piped());
99
100    let start = Instant::now();
101    let mut child = cmd
102        .spawn()
103        .map_err(|e| miette!("Failed to spawn {binary}: {e}"))?;
104
105    let child_id = child.id();
106    let stdout_pipe = child.stdout.take();
107    let stderr_pipe = child.stderr.take();
108
109    let last_activity = Arc::new(std::sync::Mutex::new(Instant::now()));
110    let timed_out = Arc::new(AtomicBool::new(false));
111    let done = Arc::new(AtomicBool::new(false));
112
113    let last_activity_stdout = Arc::clone(&last_activity);
114    let stdout_handle = thread::spawn(move || {
115        stream_pipe(stdout_pipe, &last_activity_stdout, StreamTarget::Stdout)
116    });
117
118    let last_activity_stderr = Arc::clone(&last_activity);
119    let stderr_handle = thread::spawn(move || {
120        stream_pipe(stderr_pipe, &last_activity_stderr, StreamTarget::Stderr)
121    });
122
123    let timeout = config
124        .inactivity_timeout
125        .unwrap_or(DEFAULT_INACTIVITY_TIMEOUT);
126    let last_activity_monitor = Arc::clone(&last_activity);
127    let timed_out_monitor = Arc::clone(&timed_out);
128    let done_monitor = Arc::clone(&done);
129
130    let monitor_handle = thread::spawn(move || {
131        monitor_timeout(
132            child_id,
133            timeout,
134            &last_activity_monitor,
135            &timed_out_monitor,
136            &done_monitor,
137        )
138    });
139
140    let status = child
141        .wait()
142        .map_err(|e| miette!("Failed to wait for {binary}: {e}"))?;
143    done.store(true, Ordering::SeqCst);
144
145    let stdout = stdout_handle.join().unwrap_or_default();
146    let stderr = stderr_handle.join().unwrap_or_default();
147    let _ = monitor_handle.join();
148
149    let duration = start.elapsed();
150    let was_timed_out = timed_out.load(Ordering::SeqCst);
151
152    let exit_code = if was_timed_out {
153        -1
154    } else {
155        exit_code_from_status(&status)
156    };
157
158    Ok(HarnessRunResult {
159        stdout,
160        stderr,
161        exit_code,
162        duration,
163        timed_out: was_timed_out,
164    })
165}
166
/// Reduces a [`std::process::ExitStatus`] to a single integer exit code.
///
/// Resolution order:
///
/// 1. The ordinary exit code, when the process exited normally.
/// 2. On Unix, `128 + signal` when the process was terminated by a signal
///    (`ExitStatus::code()` is `None` in that case). This is the conventional
///    shell encoding and lets [`HarnessRunResult::is_retriable`] recognise
///    crash signals such as SIGSEGV or SIGBUS.
/// 3. `1` when neither piece of information is available.
fn exit_code_from_status(status: &std::process::ExitStatus) -> i32 {
    #[cfg(unix)]
    let from_signal: Option<i32> = {
        use std::os::unix::process::ExitStatusExt;
        status.signal().map(|signal| 128 + signal)
    };
    // Signal information is only exposed on Unix.
    #[cfg(not(unix))]
    let from_signal: Option<i32> = None;

    status.code().or(from_signal).unwrap_or(1)
}
190
191/// Reads from `pipe` in byte-level chunks, forwarding output to stdout/stderr
192/// and updating `last_activity` on every read. Byte-level reads (vs line-based)
193/// ensure inactivity is tracked even when tools stream output without newlines.
194///
195/// Incomplete UTF-8 sequences at chunk boundaries are buffered and prepended to
196/// the next read, so multi-byte characters are never split by replacement chars.
197fn stream_pipe(
198    pipe: Option<impl std::io::Read>,
199    last_activity: &std::sync::Mutex<Instant>,
200    target: StreamTarget,
201) -> String {
202    let mut collected = String::new();
203    let Some(mut pipe) = pipe else {
204        return collected;
205    };
206
207    let mut buf = [0u8; 4096];
208    // Bytes from the tail of the previous read that form an incomplete UTF-8
209    // sequence. At most 3 bytes (the longest incomplete prefix of a 4-byte char).
210    let mut leftover = Vec::new();
211
212    loop {
213        let n = match pipe.read(&mut buf) {
214            Ok(0) => break,
215            Ok(n) => n,
216            Err(_err) => break,
217        };
218
219        if let Ok(mut last) = last_activity.lock() {
220            *last = Instant::now();
221        }
222
223        // Prepend any leftover bytes from the previous chunk.
224        let data = if leftover.is_empty() {
225            &buf[..n]
226        } else {
227            leftover.extend_from_slice(&buf[..n]);
228            leftover.as_slice()
229        };
230
231        // Find the longest valid UTF-8 prefix. Any trailing bytes that form an
232        // incomplete character are saved for the next iteration.
233        let (valid, remaining) = match std::str::from_utf8(data) {
234            Ok(s) => (s, &[][..]),
235            Err(e) => {
236                let valid_up_to = e.valid_up_to();
237                // SAFETY: from_utf8 guarantees bytes up to valid_up_to are valid UTF-8.
238                let valid = unsafe { std::str::from_utf8_unchecked(&data[..valid_up_to]) };
239                (valid, &data[valid_up_to..])
240            }
241        };
242
243        if !valid.is_empty() {
244            match target {
245                StreamTarget::Stdout => {
246                    print!("{valid}");
247                    let _ = std::io::stdout().flush();
248                }
249                StreamTarget::Stderr => {
250                    eprint!("{valid}");
251                    let _ = std::io::stderr().flush();
252                }
253            }
254            collected.push_str(valid);
255        }
256
257        leftover = remaining.to_vec();
258    }
259
260    // Flush any final leftover bytes (incomplete sequence at EOF) as lossy UTF-8.
261    if !leftover.is_empty() {
262        let tail = String::from_utf8_lossy(&leftover);
263        match target {
264            StreamTarget::Stdout => {
265                print!("{tail}");
266                let _ = std::io::stdout().flush();
267            }
268            StreamTarget::Stderr => {
269                eprint!("{tail}");
270                let _ = std::io::stderr().flush();
271            }
272        }
273        collected.push_str(&tail);
274    }
275
276    collected
277}
278
/// Monitors a child process for inactivity and forcefully terminates it if no
/// activity occurs within `timeout`.
///
/// The monitor wakes periodically and compares the time elapsed since
/// `last_activity` against `timeout`. Once the elapsed time meets or exceeds
/// the timeout it prints a notice to stderr, sets `timed_out`, attempts to
/// kill the process (`kill -9` on Unix, `taskkill /F /PID` on Windows), and
/// returns. It returns without killing anything as soon as it observes `done`,
/// or if `last_activity` cannot be locked.
///
/// The wake-up interval is `timeout` clamped to the range 10 ms – 1 s, so a
/// sub-second timeout is honoured promptly instead of overshooting by up to a
/// second. Waiting uses [`thread::park_timeout`], which means the owner of the
/// monitor thread can cut a wait short with `Thread::unpark` after setting
/// `done`; spurious wake-ups are harmless because every wake-up re-checks the
/// shared state.
///
/// # Parameters
///
/// - `child_id`: process identifier of the child to terminate on timeout.
/// - `timeout`: duration of allowed inactivity before termination.
/// - `last_activity`: mutex-protected `Instant` updated by output-streaming
///   threads on each read.
/// - `timed_out`: atomic flag set to `true` when a timeout-triggered
///   termination occurs.
/// - `done`: atomic flag that, when set to `true`, stops the monitor loop.
///
/// # Examples
///
/// ```ignore
/// use std::sync::atomic::{AtomicBool, Ordering};
/// use std::sync::{Arc, Mutex};
/// use std::thread;
/// use std::time::{Duration, Instant};
///
/// // Shared state.
/// let last_activity = Arc::new(Mutex::new(Instant::now()));
/// let timed_out = Arc::new(AtomicBool::new(false));
/// let done = Arc::new(AtomicBool::new(false));
///
/// // Clones for the monitor thread.
/// let la = Arc::clone(&last_activity);
/// let to = Arc::clone(&timed_out);
/// let dn = Arc::clone(&done);
///
/// // `0` is a placeholder id; the generous timeout guarantees the monitor
/// // never gets as far as signalling it in this example.
/// let handle = thread::spawn(move || {
///     monitor_timeout(0, Duration::from_secs(60), &la, &to, &dn);
/// });
///
/// // Signal completion, wake the monitor, and join.
/// done.store(true, Ordering::SeqCst);
/// handle.thread().unpark();
/// let _ = handle.join();
/// assert!(!timed_out.load(Ordering::SeqCst));
/// ```
fn monitor_timeout(
    child_id: u32,
    timeout: Duration,
    last_activity: &std::sync::Mutex<Instant>,
    timed_out: &AtomicBool,
    done: &AtomicBool,
) {
    // Never poll less often than once a second, never more often than every
    // 10 ms, and otherwise match the timeout so short timeouts are respected.
    let check_interval = timeout.clamp(Duration::from_millis(10), Duration::from_secs(1));

    loop {
        // Wakeable wait: returns after `check_interval`, or earlier if this
        // thread is unparked.
        thread::park_timeout(check_interval);

        if done.load(Ordering::SeqCst) {
            break;
        }

        let elapsed = match last_activity.lock() {
            Ok(last) => last.elapsed(),
            Err(_poisoned) => break,
        };

        if elapsed < timeout {
            continue;
        }

        eprintln!(
            "\n=== Inactivity timeout ({:?}) reached, killing process... ===\n",
            timeout
        );
        // Record the timeout before killing so that whoever observes the
        // process exit afterwards also sees the flag.
        timed_out.store(true, Ordering::SeqCst);

        #[cfg(unix)]
        {
            let _ = std::process::Command::new("kill")
                .args(["-9", &child_id.to_string()])
                .status();
        }
        #[cfg(windows)]
        {
            let _ = std::process::Command::new("taskkill")
                .args(["/F", "/PID", &child_id.to_string()])
                .status();
        }

        break;
    }
}