use std::fmt;
use std::io;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::{AtomicBool, Ordering};
use bytes::Bytes;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::{AbortHandle, JoinHandle};
/// Abstraction over "forcibly stop the child process".
///
/// Implementors must be `Send + Sync`. [`ProcessHandle`] stores one boxed
/// terminator behind a mutex and calls [`ChildTerminator::kill`] at most once
/// when the handle is terminated or dropped.
pub trait ChildTerminator: Send + Sync {
/// Requests that the child process be killed.
///
/// # Errors
///
/// Returns an [`io::Error`] if the kill request fails.
fn kill(&mut self) -> io::Result<()>;
}
/// Type-erased handles named after the slave and master sides of a PTY.
///
/// Nothing in this file reads the fields; they are only owned here so that
/// the boxed values are dropped when this struct is dropped.
/// NOTE(review): presumably these keep the PTY endpoints open for the lifetime
/// of the child process — confirm against the spawning code.
pub struct PtyHandles {
/// Optional slave-side handle (may be absent).
pub _slave: Option<Box<dyn Send>>,
/// Master-side handle (always present).
pub _master: Box<dyn Send>,
}
impl fmt::Debug for PtyHandles {
    /// Prints only the type name; the boxed handles are opaque and expose no
    /// `Debug` output of their own.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A field-less `debug_struct` emits just the name, so write it directly.
        f.write_str("PtyHandles")
    }
}
/// Handle to a spawned child process and the background tasks serving it.
///
/// Terminating the handle (explicitly via `terminate` or implicitly on drop)
/// invokes the stored [`ChildTerminator`] and aborts every stored task handle.
pub struct ProcessHandle {
/// Sender for input bytes; `write` sends on it and `writer_sender` hands out clones.
writer_tx: mpsc::Sender<Vec<u8>>,
/// Broadcast sender for output chunks; `output_receiver` subscribes to it.
output_tx: broadcast::Sender<Bytes>,
/// Terminator for the child; taken (left as `None`) the first time the handle is terminated.
killer: StdMutex<Option<Box<dyn ChildTerminator>>>,
/// Join handle whose completion `is_output_drained` reports; taken and aborted on terminate.
reader_handle: StdMutex<Option<JoinHandle<()>>>,
/// Extra abort handles that are drained and aborted on terminate.
reader_abort_handles: StdMutex<Vec<AbortHandle>>,
/// Join handle taken and aborted on terminate.
/// NOTE(review): presumably the task forwarding `writer_tx` data to the child — confirm.
writer_handle: StdMutex<Option<JoinHandle<()>>>,
/// Join handle taken and aborted on terminate.
/// NOTE(review): presumably the task awaiting the child's exit — confirm.
wait_handle: StdMutex<Option<JoinHandle<()>>>,
/// Shared flag read by `has_exited`; it is never written in this impl.
exit_status: Arc<AtomicBool>,
/// Shared optional exit code read by `exit_code`; it is never written in this impl.
exit_code: Arc<StdMutex<Option<i32>>>,
/// Optional PTY handles that are owned but never read here, so they are
/// released only when this handle is dropped.
_pty_handles: StdMutex<Option<PtyHandles>>,
}
impl fmt::Debug for ProcessHandle {
    /// Shows only the observable exit state (`has_exited`, `exit_code`);
    /// channels, task handles and the terminator are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let exited = self.has_exited();
        let code = self.exit_code();
        let mut out = f.debug_struct("ProcessHandle");
        out.field("has_exited", &exited);
        out.field("exit_code", &code);
        out.finish()
    }
}
impl ProcessHandle {
#[allow(clippy::too_many_arguments)]
pub fn new(
writer_tx: mpsc::Sender<Vec<u8>>,
output_tx: broadcast::Sender<Bytes>,
initial_output_rx: broadcast::Receiver<Bytes>,
killer: Box<dyn ChildTerminator>,
reader_handle: JoinHandle<()>,
reader_abort_handles: Vec<AbortHandle>,
writer_handle: JoinHandle<()>,
wait_handle: JoinHandle<()>,
exit_status: Arc<AtomicBool>,
exit_code: Arc<StdMutex<Option<i32>>>,
pty_handles: Option<PtyHandles>,
) -> (Self, broadcast::Receiver<Bytes>) {
(
Self {
writer_tx,
output_tx,
killer: StdMutex::new(Some(killer)),
reader_handle: StdMutex::new(Some(reader_handle)),
reader_abort_handles: StdMutex::new(reader_abort_handles),
writer_handle: StdMutex::new(Some(writer_handle)),
wait_handle: StdMutex::new(Some(wait_handle)),
exit_status,
exit_code,
_pty_handles: StdMutex::new(pty_handles),
},
initial_output_rx,
)
}
pub fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
self.writer_tx.clone()
}
pub fn output_receiver(&self) -> broadcast::Receiver<Bytes> {
self.output_tx.subscribe()
}
pub fn has_exited(&self) -> bool {
self.exit_status.load(Ordering::SeqCst)
}
pub fn exit_code(&self) -> Option<i32> {
self.exit_code.lock().ok().and_then(|guard| *guard)
}
pub fn is_output_drained(&self) -> bool {
self.reader_handle
.lock()
.ok()
.and_then(|guard| guard.as_ref().map(JoinHandle::is_finished))
.unwrap_or(true)
}
pub fn terminate(&self) {
self.terminate_internal();
}
fn terminate_internal(&self) {
if let Ok(mut killer_opt) = self.killer.lock()
&& let Some(mut killer) = killer_opt.take()
{
let _ = killer.kill();
}
self.abort_tasks();
}
fn abort_tasks(&self) {
if let Ok(mut h) = self.reader_handle.lock()
&& let Some(handle) = h.take()
{
handle.abort();
}
if let Ok(mut handles) = self.reader_abort_handles.lock() {
for handle in handles.drain(..) {
handle.abort();
}
}
if let Ok(mut h) = self.writer_handle.lock()
&& let Some(handle) = h.take()
{
handle.abort();
}
if let Ok(mut h) = self.wait_handle.lock()
&& let Some(handle) = h.take()
{
handle.abort();
}
}
pub fn is_running(&self) -> bool {
!self.has_exited() && !self.is_writer_closed()
}
pub async fn write(
&self,
bytes: impl Into<Vec<u8>>,
) -> Result<(), mpsc::error::SendError<Vec<u8>>> {
self.writer_tx.send(bytes.into()).await
}
pub fn is_writer_closed(&self) -> bool {
self.writer_tx.is_closed()
}
}
impl Drop for ProcessHandle {
    /// Runs the same cleanup as an explicit `terminate`: the terminator is
    /// invoked if it has not been taken yet, and all stored tasks are aborted.
    fn drop(&mut self) {
        Self::terminate_internal(self);
    }
}
/// Bundle describing a freshly spawned process: its handle plus the receivers
/// consumed by [`SpawnedProcess::wait_with_output`].
#[derive(Debug)]
pub struct SpawnedProcess {
/// Handle to the process; dropping it terminates the process and its tasks.
pub session: ProcessHandle,
/// Receiver of output chunks.
pub output_rx: broadcast::Receiver<Bytes>,
/// One-shot receiver that yields the exit code.
pub exit_rx: oneshot::Receiver<i32>,
}
impl SpawnedProcess {
    /// Consumes the bundle and collects output until the process exits or
    /// `timeout_ms` milliseconds elapse.
    ///
    /// Returns the collected bytes and the exit code as produced by
    /// [`collect_output_until_exit`].
    pub async fn wait_with_output(self, timeout_ms: u64) -> (Vec<u8>, i32) {
        // Keep the handle bound until collection is done: dropping it
        // terminates the process and aborts its tasks.
        let Self {
            session: _session,
            output_rx,
            exit_rx,
        } = self;
        collect_output_until_exit(output_rx, exit_rx, timeout_ms).await
    }
}
/// Accumulates broadcast output until the exit code arrives or the timeout
/// elapses.
///
/// * `output_rx` – receiver of output chunks; chunks are appended in arrival
///   order. Lag notifications are skipped (the missed chunks are lost).
/// * `exit_rx` – one-shot receiver for the exit code. If its sender is dropped
///   without sending, the exit code is reported as `-1`.
/// * `timeout_ms` – overall wait budget in milliseconds, measured from the call.
///
/// Once the exit code is received, trailing output is drained for at most
/// 500 ms, stopping early after 50 ms without a new chunk or when the output
/// channel closes.
///
/// Returns `(collected_bytes, exit_code)`; `exit_code` is `-1` on timeout.
pub async fn collect_output_until_exit(
    mut output_rx: broadcast::Receiver<Bytes>,
    exit_rx: oneshot::Receiver<i32>,
    timeout_ms: u64,
) -> (Vec<u8>, i32) {
    // After exit: stop draining once no chunk arrives within this window...
    const DRAIN_QUIET: tokio::time::Duration = tokio::time::Duration::from_millis(50);
    // ...and never spend longer than this draining in total.
    const DRAIN_MAX: tokio::time::Duration = tokio::time::Duration::from_millis(500);

    let mut collected = Vec::new();
    let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
    // Cleared once the output channel reports `Closed`. A closed receiver
    // resolves immediately on every `recv`, so polling it again would turn
    // this loop into a busy spin until exit or timeout.
    let mut output_open = true;
    tokio::pin!(exit_rx);
    loop {
        tokio::select! {
            res = output_rx.recv(), if output_open => {
                match res {
                    Ok(chunk) => collected.extend_from_slice(&chunk),
                    // Missed chunks cannot be recovered; keep reading newer ones.
                    Err(broadcast::error::RecvError::Lagged(_)) => {}
                    Err(broadcast::error::RecvError::Closed) => output_open = false,
                }
            }
            res = &mut exit_rx => {
                let code = res.unwrap_or(-1);
                let max_deadline = tokio::time::Instant::now() + DRAIN_MAX;
                while tokio::time::Instant::now() < max_deadline {
                    match tokio::time::timeout(DRAIN_QUIET, output_rx.recv()).await {
                        Ok(Ok(chunk)) => collected.extend_from_slice(&chunk),
                        Ok(Err(broadcast::error::RecvError::Lagged(_))) => continue,
                        Ok(Err(broadcast::error::RecvError::Closed)) => break,
                        // Quiet window elapsed without output: assume drained.
                        Err(_) => break,
                    }
                }
                return (collected, code);
            }
            _ = tokio::time::sleep_until(deadline) => {
                return (collected, -1);
            }
        }
    }
}
/// Alias for [`ProcessHandle`].
/// NOTE(review): presumably kept so code using the `ExecCommandSession` name
/// keeps compiling — confirm against callers.
pub type ExecCommandSession = ProcessHandle;
/// Alias for [`SpawnedProcess`].
/// NOTE(review): presumably kept so code using the `SpawnedPty` name keeps
/// compiling — confirm against callers.
pub type SpawnedPty = SpawnedProcess;
#[cfg(test)]
mod tests {
    use super::*;

    /// Terminator whose `kill` always succeeds and does nothing.
    struct NoopTerminator;

    impl ChildTerminator for NoopTerminator {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Builds a handle backed by no-op tasks and capacity-1 channels.
    ///
    /// Must be called from inside a Tokio runtime because it spawns tasks.
    /// The writer receiver and the initial output receiver are dropped
    /// immediately.
    fn idle_handle(exit_status: Arc<AtomicBool>) -> ProcessHandle {
        let (writer_tx, _) = mpsc::channel(1);
        let (output_tx, initial_rx) = broadcast::channel(1);
        let exit_code = Arc::new(StdMutex::new(None));
        let (handle, _) = ProcessHandle::new(
            writer_tx,
            output_tx,
            initial_rx,
            Box::new(NoopTerminator),
            tokio::spawn(async {}),
            vec![],
            tokio::spawn(async {}),
            tokio::spawn(async {}),
            exit_status,
            exit_code,
            None,
        );
        handle
    }

    /// The `Debug` output names the type.
    #[tokio::test]
    async fn test_process_handle_debug() {
        let handle = idle_handle(Arc::new(AtomicBool::new(false)));
        let debug_str = format!("{handle:?}");
        assert!(debug_str.contains("ProcessHandle"));
    }

    /// `has_exited` mirrors the shared exit flag.
    #[tokio::test]
    async fn test_has_exited() {
        let exit_status = Arc::new(AtomicBool::new(false));
        let handle = idle_handle(Arc::clone(&exit_status));
        assert!(!handle.has_exited());
        exit_status.store(true, Ordering::SeqCst);
        assert!(handle.has_exited());
    }
}