use std::process::Stdio;
use futures::future::Future;
use process_wrap::tokio::{ChildWrapper, CommandWrap};
use tokio::{
io::AsyncRead,
process::{ChildStderr, ChildStdin, ChildStdout},
};
use super::{RxJsonRpcMessage, Transport, TxJsonRpcMessage, async_rw::AsyncRwTransport};
use crate::RoleClient;
const MAX_WAIT_ON_DROP_SECS: u64 = 3;
type ChildProcessParts = (
Box<dyn ChildWrapper>,
ChildStdout,
ChildStdin,
Option<ChildStderr>,
);
#[inline]
fn child_process(mut child: Box<dyn ChildWrapper>) -> std::io::Result<ChildProcessParts> {
let child_stdin = match child.inner_mut().stdin().take() {
Some(stdin) => stdin,
None => return Err(std::io::Error::other("stdin was already taken")),
};
let child_stdout = match child.inner_mut().stdout().take() {
Some(stdout) => stdout,
None => return Err(std::io::Error::other("stdout was already taken")),
};
let child_stderr = child.inner_mut().stderr().take();
Ok((child, child_stdout, child_stdin, child_stderr))
}
pub struct TokioChildProcess {
child: ChildWithCleanup,
transport: AsyncRwTransport<RoleClient, ChildStdout, ChildStdin>,
}
pub struct ChildWithCleanup {
inner: Option<Box<dyn ChildWrapper>>,
}
impl Drop for ChildWithCleanup {
fn drop(&mut self) {
if let Some(mut inner) = self.inner.take() {
tokio::spawn(async move {
if let Err(e) = Box::into_pin(inner.kill()).await {
tracing::warn!("Error killing child process: {}", e);
}
});
}
}
}
pin_project_lite::pin_project! {
pub struct TokioChildProcessOut {
child: ChildWithCleanup,
#[pin]
child_stdout: ChildStdout,
}
}
impl TokioChildProcessOut {
pub fn id(&self) -> Option<u32> {
self.child.inner.as_ref()?.id()
}
}
impl AsyncRead for TokioChildProcessOut {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
self.project().child_stdout.poll_read(cx, buf)
}
}
impl TokioChildProcess {
pub fn new(command: impl Into<CommandWrap>) -> std::io::Result<Self> {
let (proc, _ignored) = TokioChildProcessBuilder::new(command).spawn()?;
Ok(proc)
}
pub fn builder(command: impl Into<CommandWrap>) -> TokioChildProcessBuilder {
TokioChildProcessBuilder::new(command)
}
pub fn id(&self) -> Option<u32> {
self.child.inner.as_ref()?.id()
}
pub async fn graceful_shutdown(&mut self) -> std::io::Result<()> {
if let Some(mut child) = self.child.inner.take() {
self.transport.close().await?;
let wait_fut = child.wait();
tokio::select! {
_ = tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS)) => {
if let Err(e) = Box::into_pin(child.kill()).await {
tracing::warn!("Error killing child: {e}");
return Err(e);
}
},
res = wait_fut => {
match res {
Ok(status) => {
tracing::info!("Child exited gracefully {}", status);
}
Err(e) => {
tracing::warn!("Error waiting for child: {e}");
return Err(e);
}
}
}
}
}
Ok(())
}
pub fn into_inner(mut self) -> Option<Box<dyn ChildWrapper>> {
self.child.inner.take()
}
#[deprecated(
since = "0.5.0",
note = "use the Transport trait implementation instead"
)]
pub fn split(self) -> (TokioChildProcessOut, ChildStdin) {
unimplemented!("This method is deprecated, use the Transport trait implementation instead");
}
}
pub struct TokioChildProcessBuilder {
cmd: CommandWrap,
stdin: Stdio,
stdout: Stdio,
stderr: Stdio,
}
impl TokioChildProcessBuilder {
fn new(cmd: impl Into<CommandWrap>) -> Self {
Self {
cmd: cmd.into(),
stdin: Stdio::piped(),
stdout: Stdio::piped(),
stderr: Stdio::inherit(),
}
}
pub fn stdin(mut self, io: impl Into<Stdio>) -> Self {
self.stdin = io.into();
self
}
pub fn stdout(mut self, io: impl Into<Stdio>) -> Self {
self.stdout = io.into();
self
}
pub fn stderr(mut self, io: impl Into<Stdio>) -> Self {
self.stderr = io.into();
self
}
pub fn spawn(mut self) -> std::io::Result<(TokioChildProcess, Option<ChildStderr>)> {
self.cmd
.command_mut()
.stdin(self.stdin)
.stdout(self.stdout)
.stderr(self.stderr);
let (child, stdout, stdin, stderr_opt) = child_process(self.cmd.spawn()?)?;
let transport = AsyncRwTransport::new(stdout, stdin);
let proc = TokioChildProcess {
child: ChildWithCleanup { inner: Some(child) },
transport,
};
Ok((proc, stderr_opt))
}
}
impl Transport<RoleClient> for TokioChildProcess {
type Error = std::io::Error;
fn send(
&mut self,
item: TxJsonRpcMessage<RoleClient>,
) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
self.transport.send(item)
}
fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<RoleClient>>> + Send {
self.transport.receive()
}
fn close(&mut self) -> impl Future<Output = Result<(), Self::Error>> + Send {
self.graceful_shutdown()
}
}
pub trait ConfigureCommandExt {
fn configure(self, f: impl FnOnce(&mut Self)) -> Self;
}
impl ConfigureCommandExt for tokio::process::Command {
fn configure(mut self, f: impl FnOnce(&mut Self)) -> Self {
f(&mut self);
self
}
}
#[cfg(feature = "which-command")]
pub fn which_command(
name: impl AsRef<std::ffi::OsStr>,
) -> std::io::Result<tokio::process::Command> {
let resolved = which::which(name.as_ref())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::NotFound, e))?;
Ok(tokio::process::Command::new(resolved))
}
#[cfg(feature = "which-command")]
#[cfg(test)]
mod tests_which {
#[test]
fn which_command_resolves_known_binary() {
#[cfg(unix)]
let result = super::which_command("ls");
#[cfg(windows)]
let result = super::which_command("cmd");
assert!(result.is_ok());
}
#[test]
fn which_command_fails_for_nonexistent() {
let result = super::which_command("this_binary_definitely_does_not_exist_12345");
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::NotFound);
}
}
#[cfg(unix)]
#[cfg(test)]
mod tests {
use tokio::process::Command;
use super::*;
#[tokio::test]
async fn test_tokio_child_process_drop() {
let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| {
cmd.arg("30");
}));
assert!(r.is_ok());
let child_process = r.unwrap();
let id = child_process.id();
assert!(id.is_some());
let id = id.unwrap();
drop(child_process);
tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await;
let status = Command::new("ps")
.arg("-p")
.arg(id.to_string())
.status()
.await;
match status {
Ok(status) => {
assert!(
!status.success(),
"Process with PID {} is still running",
id
);
}
Err(e) => {
panic!("Failed to check process status: {}", e);
}
}
}
#[tokio::test]
async fn test_tokio_child_process_graceful_shutdown() {
let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| {
cmd.arg("30");
}));
assert!(r.is_ok());
let mut child_process = r.unwrap();
let id = child_process.id();
assert!(id.is_some());
let id = id.unwrap();
child_process.graceful_shutdown().await.unwrap();
tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await;
let status = Command::new("ps")
.arg("-p")
.arg(id.to_string())
.status()
.await;
match status {
Ok(status) => {
assert!(
!status.success(),
"Process with PID {} is still running",
id
);
}
Err(e) => {
panic!("Failed to check process status: {}", e);
}
}
}
}