use std::error::Error;
use std::ffi::OsStr;
use std::marker;
use std::process::Stdio;
use log;
use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader};
use tokio::process::Child;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::time::Duration;
/// Level at which a captured line of child-process output is logged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogType {
    /// Emitted with `log::info!`.
    Info,
    /// Emitted with `log::error!`.
    Error,
}
/// Supplies the success values (`T`) and errors (`E`) that a command
/// executor reports for a process.
pub trait ProcessStatus<T, E>: Send
where
    E: Error + Send,
{
    /// Status value presumably reported when the process starts.
    /// NOTE(review): not called by the executor in this file — confirm use.
    fn status_entry(&self) -> T;

    /// Status value reported when the process exits successfully.
    fn status_exit(&self) -> T;

    /// Generic failure error (e.g. spawn failure or unsuccessful exit).
    fn error_type(&self) -> E;

    /// Converts an underlying error, plus an optional message, into `E`.
    fn wrap_error<F: Error + Sync + Send + 'static>(&self, error: F, message: Option<String>) -> E;

    /// Error reported when execution exceeds its timeout.
    ///
    /// Defaults to [`ProcessStatus::error_type`].
    fn timeout_error(&self) -> E {
        self.error_type()
    }
}
/// One line read from a child process's stdout or stderr, together with the
/// level it should be logged at.
#[derive(Debug)]
pub struct LogOutputData {
    /// The text of the line.
    line: String,
    /// Selects `log::info!` or `log::error!` when the line is emitted.
    log_type: LogType,
}
/// An external command that is constructed up front and then awaited.
///
/// `S` is the success value, `E` the error type, and `P` the
/// [`ProcessStatus`] that supplies both.
// `async fn` in a public trait trips `async_fn_in_trait` because callers
// cannot name or require `Send` on the returned future; accepted deliberately.
#[allow(async_fn_in_trait)]
pub trait AsyncCommand<S, E, P>: Sized
where
    E: Error + Send,
    P: ProcessStatus<S, E> + Send,
{
    /// Creates the command for `executable_path` with `args`, using
    /// `process_type` to produce status values and errors.
    fn new<A, B>(executable_path: &OsStr, args: A, process_type: P) -> Result<Self, E>
    where
        A: IntoIterator<Item = B>,
        B: AsRef<OsStr>;

    /// Runs the command to completion, optionally bounded by `timeout`.
    async fn execute(&mut self, timeout: Option<Duration>) -> Result<S, E>;
}
pub struct AsyncCommandExecutor<S, E, P>
where
S: Send,
E: Error + Send,
P: ProcessStatus<S, E>,
Self: Send,
{
_command: tokio::process::Command,
process: Child,
process_type: P,
_marker_s: marker::PhantomData<S>,
_marker_e: marker::PhantomData<E>,
}
impl<S, E, P> AsyncCommandExecutor<S, E, P>
where
S: Send,
E: Error + Send,
P: ProcessStatus<S, E> + Send,
{
fn init(command: &mut tokio::process::Command, process_type: &P) -> Result<Child, E> {
command
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|_| process_type.error_type())
}
fn generate_command<A, B>(executable_path: &OsStr, args: A) -> tokio::process::Command
where
A: IntoIterator<Item = B>,
B: AsRef<OsStr>,
{
let mut command = tokio::process::Command::new(executable_path);
command.args(args);
command
}
async fn handle_output<R: AsyncRead + Unpin>(data: R, sender: Sender<LogOutputData>) {
let mut lines = BufReader::new(data).lines();
loop {
match lines.next_line().await {
Ok(Some(line)) => {
let io_data = LogOutputData {
line,
log_type: LogType::Info,
};
if sender.send(io_data).await.is_err() {
log::warn!("process output channel closed before stream ended");
break;
}
}
Ok(None) => break,
Err(e) => {
log::error!("Error reading process output: {}", e);
break;
}
}
}
}
async fn log_output(mut receiver: Receiver<LogOutputData>) {
while let Some(data) = receiver.recv().await {
match data.log_type {
LogType::Info => {
log::info!("{}", data.line);
}
LogType::Error => {
log::error!("{}", data.line);
}
}
}
}
async fn run_process(&mut self) -> Result<S, E> {
let exit_status = self
.process
.wait()
.await
.map_err(|e| self.process_type.wrap_error(e, None))?;
if exit_status.success() {
Ok(self.process_type.status_exit())
} else {
Err(self.process_type.error_type())
}
}
async fn command_execution(&mut self) -> Result<S, E> {
let (sender, receiver) = tokio::sync::mpsc::channel::<LogOutputData>(1000);
let res = self.run_process().await;
if let Some(stdout) = self.process.stdout.take() {
let tx = sender.clone();
drop(tokio::task::spawn(async move {
Self::handle_output(stdout, tx).await;
}));
}
if let Some(stderr) = self.process.stderr.take() {
let tx = sender.clone();
drop(tokio::task::spawn(async move {
Self::handle_output(stderr, tx).await;
}));
}
drop(sender);
drop(tokio::task::spawn(async {
Self::log_output(receiver).await;
}));
res
}
}
impl<S, E, P> AsyncCommand<S, E, P> for AsyncCommandExecutor<S, E, P>
where
S: Send,
E: Error + Send,
P: ProcessStatus<S, E> + Send,
{
fn new<A, B>(executable_path: &OsStr, args: A, process_type: P) -> Result<Self, E>
where
A: IntoIterator<Item = B>,
B: AsRef<OsStr>,
{
let mut _command = Self::generate_command(executable_path, args);
let process = Self::init(&mut _command, &process_type)?;
Ok(AsyncCommandExecutor {
_command,
process,
process_type,
_marker_s: Default::default(),
_marker_e: Default::default(),
})
}
async fn execute(&mut self, timeout: Option<Duration>) -> Result<S, E> {
match timeout {
None => self.command_execution().await,
Some(duration) => tokio::time::timeout(duration, self.command_execution())
.await
.map_err(|_| self.process_type.timeout_error())?,
}
}
}