1use std::fmt;
9use std::io;
10use std::sync::Arc;
11use std::sync::Mutex as StdMutex;
12use std::sync::atomic::{AtomicBool, Ordering};
13
14use bytes::Bytes;
15use tokio::sync::{broadcast, mpsc, oneshot};
16use tokio::task::{AbortHandle, JoinHandle};
17
/// Capability to forcibly stop a spawned child process.
///
/// `ProcessHandle` stores an implementor as `Box<dyn ChildTerminator>` behind a
/// mutex, so implementors must be both `Sync` and `Send`.
pub trait ChildTerminator: Sync + Send {
    /// Asks for the child process to be killed.
    ///
    /// # Errors
    ///
    /// Returns whatever I/O error the implementor encountered while killing the
    /// child.
    fn kill(&mut self) -> io::Result<()>;
}
25
/// Type-erased PTY endpoints that are kept only so they are not dropped early.
///
/// Nothing reads these fields; owning them ties the endpoints' lifetime to the
/// owner of this value. Fields drop in declaration order, so `_slave` is
/// released before `_master`.
pub struct PtyHandles {
    /// Slave end of the PTY, when it is still being retained.
    pub _slave: Option<Box<dyn Send + 'static>>,
    /// Master end of the PTY.
    pub _master: Box<dyn Send + 'static>,
}
36
37impl fmt::Debug for PtyHandles {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 f.debug_struct("PtyHandles").finish()
40 }
41}
42
43pub struct ProcessHandle {
51 writer_tx: mpsc::Sender<Vec<u8>>,
52 output_tx: broadcast::Sender<Bytes>,
53 killer: StdMutex<Option<Box<dyn ChildTerminator>>>,
54 reader_handle: StdMutex<Option<JoinHandle<()>>>,
55 reader_abort_handles: StdMutex<Vec<AbortHandle>>,
56 writer_handle: StdMutex<Option<JoinHandle<()>>>,
57 wait_handle: StdMutex<Option<JoinHandle<()>>>,
58 exit_status: Arc<AtomicBool>,
59 exit_code: Arc<StdMutex<Option<i32>>>,
60 _pty_handles: StdMutex<Option<PtyHandles>>,
62}
63
64impl fmt::Debug for ProcessHandle {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 f.debug_struct("ProcessHandle")
67 .field("has_exited", &self.has_exited())
68 .field("exit_code", &self.exit_code())
69 .finish()
70 }
71}
72
73impl ProcessHandle {
74 #[allow(clippy::too_many_arguments)]
76 pub fn new(
77 writer_tx: mpsc::Sender<Vec<u8>>,
78 output_tx: broadcast::Sender<Bytes>,
79 initial_output_rx: broadcast::Receiver<Bytes>,
80 killer: Box<dyn ChildTerminator>,
81 reader_handle: JoinHandle<()>,
82 reader_abort_handles: Vec<AbortHandle>,
83 writer_handle: JoinHandle<()>,
84 wait_handle: JoinHandle<()>,
85 exit_status: Arc<AtomicBool>,
86 exit_code: Arc<StdMutex<Option<i32>>>,
87 pty_handles: Option<PtyHandles>,
88 ) -> (Self, broadcast::Receiver<Bytes>) {
89 (
90 Self {
91 writer_tx,
92 output_tx,
93 killer: StdMutex::new(Some(killer)),
94 reader_handle: StdMutex::new(Some(reader_handle)),
95 reader_abort_handles: StdMutex::new(reader_abort_handles),
96 writer_handle: StdMutex::new(Some(writer_handle)),
97 wait_handle: StdMutex::new(Some(wait_handle)),
98 exit_status,
99 exit_code,
100 _pty_handles: StdMutex::new(pty_handles),
101 },
102 initial_output_rx,
103 )
104 }
105
106 #[inline]
114 pub fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
115 self.writer_tx.clone()
116 }
117
118 #[inline]
123 pub fn output_receiver(&self) -> broadcast::Receiver<Bytes> {
124 self.output_tx.subscribe()
125 }
126
127 #[inline]
129 pub fn has_exited(&self) -> bool {
130 self.exit_status.load(Ordering::SeqCst)
131 }
132
133 #[inline]
135 pub fn exit_code(&self) -> Option<i32> {
136 self.exit_code.lock().ok().and_then(|guard| *guard)
137 }
138
139 #[inline]
141 pub fn is_output_drained(&self) -> bool {
142 self.reader_handle
143 .lock()
144 .ok()
145 .and_then(|guard| guard.as_ref().map(JoinHandle::is_finished))
146 .unwrap_or(true)
147 }
148
149 pub fn terminate(&self) {
153 self.terminate_internal();
154 }
155
156 fn terminate_internal(&self) {
158 if let Ok(mut killer_opt) = self.killer.lock()
160 && let Some(mut killer) = killer_opt.take()
161 {
162 let _ = killer.kill();
163 }
164
165 self.abort_tasks();
166 }
167
168 fn abort_tasks(&self) {
170 if let Ok(mut h) = self.reader_handle.lock()
172 && let Some(handle) = h.take()
173 {
174 handle.abort();
175 }
176
177 if let Ok(mut handles) = self.reader_abort_handles.lock() {
179 for handle in handles.drain(..) {
180 handle.abort();
181 }
182 }
183
184 if let Ok(mut h) = self.writer_handle.lock()
186 && let Some(handle) = h.take()
187 {
188 handle.abort();
189 }
190
191 if let Ok(mut h) = self.wait_handle.lock()
193 && let Some(handle) = h.take()
194 {
195 handle.abort();
196 }
197 }
198
199 #[inline]
201 pub fn is_running(&self) -> bool {
202 !self.has_exited() && !self.is_writer_closed()
203 }
204
205 pub async fn write(
209 &self,
210 bytes: impl Into<Vec<u8>>,
211 ) -> Result<(), mpsc::error::SendError<Vec<u8>>> {
212 self.writer_tx.send(bytes.into()).await
213 }
214
215 #[inline]
217 pub fn is_writer_closed(&self) -> bool {
218 self.writer_tx.is_closed()
219 }
220}
221
222impl Drop for ProcessHandle {
223 fn drop(&mut self) {
224 self.terminate_internal();
225 }
226}
227
228#[derive(Debug)]
232pub struct SpawnedProcess {
233 pub session: ProcessHandle,
235 pub output_rx: broadcast::Receiver<Bytes>,
237 pub exit_rx: oneshot::Receiver<i32>,
239}
240
241impl SpawnedProcess {
242 pub async fn wait_with_output(self, timeout_ms: u64) -> (Vec<u8>, i32) {
246 collect_output_until_exit(self.output_rx, self.exit_rx, timeout_ms).await
247 }
248}
249
250pub async fn collect_output_until_exit(
254 mut output_rx: broadcast::Receiver<Bytes>,
255 exit_rx: oneshot::Receiver<i32>,
256 timeout_ms: u64,
257) -> (Vec<u8>, i32) {
258 let mut collected = Vec::new();
259 let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
260 tokio::pin!(exit_rx);
261
262 loop {
263 tokio::select! {
264 res = output_rx.recv() => {
265 if let Ok(chunk) = res {
266 collected.extend_from_slice(&chunk);
267 }
268 }
269 res = &mut exit_rx => {
270 let code = res.unwrap_or(-1);
271 let quiet = tokio::time::Duration::from_millis(50);
273 let max_deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(500);
274
275 while tokio::time::Instant::now() < max_deadline {
276 match tokio::time::timeout(quiet, output_rx.recv()).await {
277 Ok(Ok(chunk)) => collected.extend_from_slice(&chunk),
278 Ok(Err(broadcast::error::RecvError::Lagged(count))) => {
279 eprintln!("[vtcode] output stream lagged ({} dropped)", count);
280 continue;
281 }
282 Ok(Err(broadcast::error::RecvError::Closed)) => break,
283 Err(_) => break, }
285 }
286 return (collected, code);
287 }
288 _ = tokio::time::sleep_until(deadline) => {
289 return (collected, -1);
290 }
291 }
292 }
293}
294
295pub type ExecCommandSession = ProcessHandle;
297
298pub type SpawnedPty = SpawnedProcess;
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Terminator whose `kill` always succeeds and does nothing.
    struct NoopTerminator;

    impl ChildTerminator for NoopTerminator {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Builds a handle around no-op tasks that shares the given exit flag and
    /// exit-code slot. The writer channel's receiver is dropped immediately.
    ///
    /// Must run inside a Tokio runtime because it spawns tasks.
    fn handle_with(
        exit_status: Arc<AtomicBool>,
        exit_code: Arc<StdMutex<Option<i32>>>,
    ) -> ProcessHandle {
        let (writer_tx, _) = mpsc::channel(1);
        let (output_tx, initial_rx) = broadcast::channel(1);

        let (handle, _) = ProcessHandle::new(
            writer_tx,
            output_tx,
            initial_rx,
            Box::new(NoopTerminator),
            tokio::spawn(async {}),
            Vec::new(),
            tokio::spawn(async {}),
            tokio::spawn(async {}),
            exit_status,
            exit_code,
            None,
        );
        handle
    }

    #[tokio::test]
    async fn test_process_handle_debug() {
        let handle = handle_with(
            Arc::new(AtomicBool::new(false)),
            Arc::new(StdMutex::new(None)),
        );

        let rendered = format!("{handle:?}");
        assert!(rendered.contains("ProcessHandle"));
    }

    #[tokio::test]
    async fn test_has_exited() {
        let exit_status = Arc::new(AtomicBool::new(false));
        let handle = handle_with(Arc::clone(&exit_status), Arc::new(StdMutex::new(None)));

        assert!(!handle.has_exited());
        exit_status.store(true, Ordering::SeqCst);
        assert!(handle.has_exited());
    }
}