Skip to main content

vtcode_bash_runner/
process.rs

1//! Unified process handle types for PTY and pipe backends.
2//!
3//! This module provides abstractions for interacting with spawned processes
4//! regardless of whether they use a PTY or regular pipes.
5//!
6//! Inspired by codex-rs/utils/pty process handle patterns.
7
8use std::fmt;
9use std::io;
10use std::sync::Arc;
11use std::sync::Mutex as StdMutex;
12use std::sync::atomic::{AtomicBool, Ordering};
13
14use bytes::Bytes;
15use tokio::sync::{broadcast, mpsc, oneshot};
16use tokio::task::{AbortHandle, JoinHandle};
17
/// Trait for process termination strategies.
///
/// Different backends (PTY vs pipe) may need different termination approaches.
///
/// [`ProcessHandle`] stores its terminator behind a mutex and takes it out on
/// the first `terminate()` call (or on drop), so `kill` is invoked at most once
/// per handle, and the returned error is discarded (termination is best-effort).
pub trait ChildTerminator: Send + Sync {
    /// Kill the child process.
    ///
    /// # Errors
    ///
    /// Implementations return an I/O error when the kill attempt fails. The
    /// callers in this module ignore it.
    fn kill(&mut self) -> io::Result<()>;
}
25
/// Optional PTY-specific handles that must be preserved.
///
/// For PTY processes, the slave handle must be kept alive because the process
/// will receive SIGHUP if it's closed.
///
/// Both handles are type-erased as `Box<dyn Send>`. Nothing in this module
/// reads them; they exist only to be owned (inside a [`ProcessHandle`]) so the
/// underlying PTY ends are dropped no earlier than the handle itself.
pub struct PtyHandles {
    /// The slave PTY handle (kept alive to prevent SIGHUP).
    ///
    /// `None` when the spawner did not hand a slave handle over.
    pub _slave: Option<Box<dyn Send>>,
    /// The master PTY handle.
    pub _master: Box<dyn Send>,
}
36
37impl fmt::Debug for PtyHandles {
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        f.debug_struct("PtyHandles").finish()
40    }
41}
42
/// Handle for driving an interactive or non-interactive process.
///
/// This provides a unified interface for both PTY and pipe-based processes:
/// - Write to stdin via `writer_sender()`
/// - Read merged stdout/stderr via `output_receiver()`
/// - Check exit status via `has_exited()` and `exit_code()`
/// - Clean up via `terminate()`
///
/// Dropping the handle performs the same cleanup as `terminate()`.
pub struct ProcessHandle {
    // Sending half of the stdin channel handed out by `writer_sender()`.
    writer_tx: mpsc::Sender<Vec<u8>>,
    // Source of the output broadcast; `output_receiver()` subscribes to it.
    output_tx: broadcast::Sender<Bytes>,
    // Taken (left as `None`) on the first termination so `kill` runs at most once.
    killer: StdMutex<Option<Box<dyn ChildTerminator>>>,
    // Output reader task; `None` once aborted. Polled by `is_output_drained()`.
    reader_handle: StdMutex<Option<JoinHandle<()>>>,
    // Additional abort handles aborted together with `reader_handle`
    // (presumably one per reader stream — confirm with the spawn helpers).
    reader_abort_handles: StdMutex<Vec<AbortHandle>>,
    // Stdin writer task; `None` once aborted.
    writer_handle: StdMutex<Option<JoinHandle<()>>>,
    // Exit-waiting task; `None` once aborted.
    wait_handle: StdMutex<Option<JoinHandle<()>>>,
    // Shared exit flag read by `has_exited()`; set by whoever owns the other `Arc`.
    exit_status: Arc<AtomicBool>,
    // Shared exit-code slot read by `exit_code()`; `None` until a code is recorded.
    exit_code: Arc<StdMutex<Option<i32>>>,
    // PTY handles must be held for the handle's whole lifetime: closing the PTY
    // slave would deliver SIGHUP to the child (see `PtyHandles`). Fields drop in
    // declaration order after `Drop::drop`, so this one is released last.
    _pty_handles: StdMutex<Option<PtyHandles>>,
}
63
64impl fmt::Debug for ProcessHandle {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        f.debug_struct("ProcessHandle")
67            .field("has_exited", &self.has_exited())
68            .field("exit_code", &self.exit_code())
69            .finish()
70    }
71}
72
73impl ProcessHandle {
74    /// Create a new process handle with all required components.
75    #[allow(clippy::too_many_arguments)]
76    pub fn new(
77        writer_tx: mpsc::Sender<Vec<u8>>,
78        output_tx: broadcast::Sender<Bytes>,
79        initial_output_rx: broadcast::Receiver<Bytes>,
80        killer: Box<dyn ChildTerminator>,
81        reader_handle: JoinHandle<()>,
82        reader_abort_handles: Vec<AbortHandle>,
83        writer_handle: JoinHandle<()>,
84        wait_handle: JoinHandle<()>,
85        exit_status: Arc<AtomicBool>,
86        exit_code: Arc<StdMutex<Option<i32>>>,
87        pty_handles: Option<PtyHandles>,
88    ) -> (Self, broadcast::Receiver<Bytes>) {
89        (
90            Self {
91                writer_tx,
92                output_tx,
93                killer: StdMutex::new(Some(killer)),
94                reader_handle: StdMutex::new(Some(reader_handle)),
95                reader_abort_handles: StdMutex::new(reader_abort_handles),
96                writer_handle: StdMutex::new(Some(writer_handle)),
97                wait_handle: StdMutex::new(Some(wait_handle)),
98                exit_status,
99                exit_code,
100                _pty_handles: StdMutex::new(pty_handles),
101            },
102            initial_output_rx,
103        )
104    }
105
106    /// Returns a channel sender for writing raw bytes to the child stdin.
107    ///
108    /// # Example
109    /// ```ignore
110    /// let writer = handle.writer_sender();
111    /// writer.send(b"input\n".to_vec()).await?;
112    /// ```
113    #[inline]
114    pub fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
115        self.writer_tx.clone()
116    }
117
118    /// Returns a broadcast receiver that yields stdout/stderr chunks.
119    ///
120    /// Multiple receivers can be created; each receives all output from the
121    /// point of subscription.
122    #[inline]
123    pub fn output_receiver(&self) -> broadcast::Receiver<Bytes> {
124        self.output_tx.subscribe()
125    }
126
127    /// True if the child process has exited.
128    #[inline]
129    pub fn has_exited(&self) -> bool {
130        self.exit_status.load(Ordering::SeqCst)
131    }
132
133    /// Returns the exit code if the process has exited.
134    #[inline]
135    pub fn exit_code(&self) -> Option<i32> {
136        self.exit_code.lock().ok().and_then(|guard| *guard)
137    }
138
139    /// True once the stdout/stderr reader task has drained the child streams.
140    #[inline]
141    pub fn is_output_drained(&self) -> bool {
142        self.reader_handle
143            .lock()
144            .ok()
145            .and_then(|guard| guard.as_ref().map(JoinHandle::is_finished))
146            .unwrap_or(true)
147    }
148
149    /// Attempts to kill the child and abort helper tasks.
150    ///
151    /// This is idempotent and safe to call multiple times.
152    pub fn terminate(&self) {
153        self.terminate_internal();
154    }
155
156    /// Internal termination that aborts all tasks.
157    fn terminate_internal(&self) {
158        // Kill the child process
159        if let Ok(mut killer_opt) = self.killer.lock()
160            && let Some(mut killer) = killer_opt.take()
161        {
162            let _ = killer.kill();
163        }
164
165        self.abort_tasks();
166    }
167
168    /// Abort all background tasks associated with this process.
169    fn abort_tasks(&self) {
170        // Abort reader handle
171        if let Ok(mut h) = self.reader_handle.lock()
172            && let Some(handle) = h.take()
173        {
174            handle.abort();
175        }
176
177        // Abort individual reader abort handles
178        if let Ok(mut handles) = self.reader_abort_handles.lock() {
179            for handle in handles.drain(..) {
180                handle.abort();
181            }
182        }
183
184        // Abort writer handle
185        if let Ok(mut h) = self.writer_handle.lock()
186            && let Some(handle) = h.take()
187        {
188            handle.abort();
189        }
190
191        // Abort wait handle
192        if let Ok(mut h) = self.wait_handle.lock()
193            && let Some(handle) = h.take()
194        {
195            handle.abort();
196        }
197    }
198
199    /// Check if the process is still running.
200    #[inline]
201    pub fn is_running(&self) -> bool {
202        !self.has_exited() && !self.is_writer_closed()
203    }
204
205    /// Send bytes to the process stdin.
206    ///
207    /// Returns an error if the stdin channel is closed.
208    pub async fn write(
209        &self,
210        bytes: impl Into<Vec<u8>>,
211    ) -> Result<(), mpsc::error::SendError<Vec<u8>>> {
212        self.writer_tx.send(bytes.into()).await
213    }
214
215    /// Check if the writer channel is closed.
216    #[inline]
217    pub fn is_writer_closed(&self) -> bool {
218        self.writer_tx.is_closed()
219    }
220}
221
222impl Drop for ProcessHandle {
223    fn drop(&mut self) {
224        self.terminate_internal();
225    }
226}
227
/// Return value from spawn helpers (PTY or pipe).
///
/// Bundles the process handle with receivers for output and exit notification.
///
/// Fields drop in declaration order, so dropping a whole `SpawnedProcess`
/// drops `session` first. `ProcessHandle` implements `Drop`, which kills the
/// child and aborts its helper tasks.
#[derive(Debug)]
pub struct SpawnedProcess {
    /// Handle for interacting with the process.
    ///
    /// Keep it alive for as long as the child should keep running.
    pub session: ProcessHandle,
    /// Receiver for stdout/stderr output chunks.
    pub output_rx: broadcast::Receiver<Bytes>,
    /// Receiver for exit code (receives once when process exits).
    pub exit_rx: oneshot::Receiver<i32>,
}
240
241impl SpawnedProcess {
242    /// Convenience method to wait for the process to exit and collect output.
243    ///
244    /// Returns (collected_output, exit_code).
245    pub async fn wait_with_output(self, timeout_ms: u64) -> (Vec<u8>, i32) {
246        collect_output_until_exit(self.output_rx, self.exit_rx, timeout_ms).await
247    }
248}
249
250/// Collect output from a process until it exits or times out.
251///
252/// This is useful for tests and simple use cases where you want all output.
253pub async fn collect_output_until_exit(
254    mut output_rx: broadcast::Receiver<Bytes>,
255    exit_rx: oneshot::Receiver<i32>,
256    timeout_ms: u64,
257) -> (Vec<u8>, i32) {
258    let mut collected = Vec::new();
259    let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
260    tokio::pin!(exit_rx);
261
262    loop {
263        tokio::select! {
264            res = output_rx.recv() => {
265                if let Ok(chunk) = res {
266                    collected.extend_from_slice(&chunk);
267                }
268            }
269            res = &mut exit_rx => {
270                let code = res.unwrap_or(-1);
271                // Drain remaining output briefly after exit
272                let quiet = tokio::time::Duration::from_millis(50);
273                let max_deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(500);
274
275                while tokio::time::Instant::now() < max_deadline {
276                    match tokio::time::timeout(quiet, output_rx.recv()).await {
277                        Ok(Ok(chunk)) => collected.extend_from_slice(&chunk),
278                        Ok(Err(broadcast::error::RecvError::Lagged(count))) => {
279                            eprintln!("[vtcode] output stream lagged ({} dropped)", count);
280                            continue;
281                        }
282                        Ok(Err(broadcast::error::RecvError::Closed)) => break,
283                        Err(_) => break, // Timeout - quiet period reached
284                    }
285                }
286                return (collected, code);
287            }
288            _ = tokio::time::sleep_until(deadline) => {
289                return (collected, -1);
290            }
291        }
292    }
293}
294
/// Backwards-compatible alias for [`ProcessHandle`].
///
/// Kept so existing callers keep compiling; prefer naming `ProcessHandle`
/// directly in new code.
pub type ExecCommandSession = ProcessHandle;

/// Backwards-compatible alias for [`SpawnedProcess`].
///
/// Despite the name, `SpawnedProcess` is returned for both PTY- and
/// pipe-backed processes.
pub type SpawnedPty = SpawnedProcess;
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Terminator that never touches a real process.
    struct NoopTerminator;

    impl ChildTerminator for NoopTerminator {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Builds a `ProcessHandle` whose helper tasks are empty futures and whose
    /// exit flag is `exit_status`. It spawns tasks, so it must be called from
    /// inside a Tokio runtime.
    fn spawn_test_handle(exit_status: Arc<AtomicBool>) -> ProcessHandle {
        let (writer_tx, _) = mpsc::channel(1);
        let (output_tx, initial_rx) = broadcast::channel(1);
        let exit_code = Arc::new(StdMutex::new(None));

        let (handle, _) = ProcessHandle::new(
            writer_tx,
            output_tx,
            initial_rx,
            Box::new(NoopTerminator),
            tokio::spawn(async {}),
            Vec::new(),
            tokio::spawn(async {}),
            tokio::spawn(async {}),
            exit_status,
            exit_code,
            None,
        );
        handle
    }

    #[tokio::test]
    async fn test_process_handle_debug() {
        // Formatting must not panic and must name the type.
        let handle = spawn_test_handle(Arc::new(AtomicBool::new(false)));
        let rendered = format!("{handle:?}");
        assert!(rendered.contains("ProcessHandle"));
    }

    #[tokio::test]
    async fn test_has_exited() {
        let exit_status = Arc::new(AtomicBool::new(false));
        let handle = spawn_test_handle(Arc::clone(&exit_status));

        assert!(!handle.has_exited());
        exit_status.store(true, Ordering::SeqCst);
        assert!(handle.has_exited());
    }
}