#![doc = include_str!("../README.md")]
use anyhow::{Result, anyhow};
use log::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::fmt::{Debug, Formatter};
use std::path::PathBuf;
use std::sync::{Arc, OnceLock};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::sync::mpsc::Sender;
use tokio::sync::broadcast;
/// Registry type: maps an OS process id to the pool entry of a started process.
type ProcessPool = Arc<Mutex<HashMap<u32, AsynchronousInteractiveProcess>>>;

/// Global process registry, lazily created by `AsynchronousInteractiveProcess::start`;
/// entries are removed again when their process exits.
static PROCESS_POOL: OnceLock<ProcessPool> = OnceLock::new();
/// Handle for interacting with a process that is registered in the process pool.
///
/// Obtained from `AsynchronousInteractiveProcess::get_process_by_pid`, which subscribes a fresh
/// receiver to the process's output broadcast. Each handle therefore independently receives every
/// line published after it was created.
#[derive(Debug)]
pub struct ProcessHandle {
    // OS process id; also the key of this process's entry in `PROCESS_POOL`.
    pid: u32,
    // Subscription to the combined stdout/stderr stream (stderr lines carry a "STDERR: " prefix).
    receiver: broadcast::Receiver<String>,
}
/// Configuration of an interactive child process, and the pool entry for a running one.
///
/// Build it with `new` and the `with_*` methods, then call `start`. `start` inserts a separate
/// copy of this struct into the process pool. Only that copy has the channel fields populated;
/// the caller's value just gains `pid`.
///
/// Only `pid`, `filename`, `arguments` and `working_directory` are (de)serialized; every other
/// field is `#[serde(skip)]`.
#[derive(Serialize, Deserialize, Default)]
pub struct AsynchronousInteractiveProcess {
    /// OS process id; `None` until `start` has run.
    pub pid: Option<u32>,
    /// Program to execute, passed to `tokio::process::Command::new`.
    pub filename: String,
    /// Command-line arguments passed to the program.
    pub arguments: Vec<String>,
    /// Directory the program is started in. `new` sets `./`; the derived `Default` leaves it empty.
    pub working_directory: PathBuf,
    // Sending half of the channel that feeds the child's stdin (pool copy only).
    #[serde(skip)]
    sender: Option<Sender<String>>,
    // Publishes stdout/stderr lines to subscribed `ProcessHandle`s (pool copy only).
    #[serde(skip)]
    output_broadcaster: Option<broadcast::Sender<String>>,
    // Held so the broadcast channel always has a receiver: `broadcast::Sender::send` fails when
    // there are none, and the output reader tasks stop at the first failed send.
    #[serde(skip)]
    _keep_alive_receiver: Option<broadcast::Receiver<String>>,
    // Inputs that arrived while the stdin channel was full; a background task retries them in order.
    #[serde(skip)]
    input_queue: VecDeque<String>,
    // Called with the exit code (or -1 when unavailable) once the process exits.
    #[serde(skip)]
    exit_callback: Option<Arc<dyn Fn(i32) + Send + Sync>>,
}
impl Debug for AsynchronousInteractiveProcess {
    /// Formats the descriptive fields only.
    ///
    /// The channels, input queue and exit callback are runtime plumbing (the callback has no
    /// `Debug` impl at all), so they are omitted. Using `debug_struct` keeps the `{:?}` output
    /// on one line while letting `{:#?}` pretty-print, which a hand-written `write!` cannot do.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsynchronousInteractiveProcess")
            .field("pid", &self.pid)
            .field("filename", &self.filename)
            .field("arguments", &self.arguments)
            .field("working_directory", &self.working_directory)
            .finish()
    }
}
impl ProcessHandle {
const DEFAULT_TIMEOUT_MS: u64 = 100;
pub async fn receive_output(&mut self) -> Result<Option<String>> {
self.receive_output_with_timeout(std::time::Duration::from_millis(Self::DEFAULT_TIMEOUT_MS)).await
}
pub async fn receive_output_with_timeout(&mut self, timeout: std::time::Duration) -> Result<Option<String>> {
const RETRY_DELAY_MS: u64 = 10;
match self.receiver.try_recv() {
Ok(msg) => return Ok(Some(msg)),
Err(broadcast::error::TryRecvError::Empty) => {}
Err(broadcast::error::TryRecvError::Closed) => {
return Err(anyhow!("Broadcast channel closed for process {}", self.pid));
}
Err(broadcast::error::TryRecvError::Lagged(_)) => {
}
}
let start_time = std::time::Instant::now();
while start_time.elapsed() < timeout {
tokio::time::sleep(tokio::time::Duration::from_millis(RETRY_DELAY_MS)).await;
match self.receiver.try_recv() {
Ok(msg) => return Ok(Some(msg)),
Err(broadcast::error::TryRecvError::Empty) => {}
Err(broadcast::error::TryRecvError::Closed) => {
return Err(anyhow!("Broadcast channel closed for process {}", self.pid));
}
Err(broadcast::error::TryRecvError::Lagged(_)) => {
}
}
}
Ok(None)
}
pub async fn send_input(&self, input: impl Into<String>) -> Result<()> {
let input_str = input.into();
if let Some(process_pool) = PROCESS_POOL.get() {
let mut pool = process_pool.lock().await;
if let Some(process) = pool.get_mut(&self.pid) {
if let Some(sender) = &process.sender {
match sender.try_send(input_str.clone()) {
Ok(_) => Ok(()),
Err(e) => match e {
tokio::sync::mpsc::error::TrySendError::Full(_) => {
process.input_queue.push_back(input_str);
Ok(())
}
tokio::sync::mpsc::error::TrySendError::Closed(_) => {
Err(anyhow!("Failed to send input: channel closed"))
}
},
}
} else {
Err(anyhow!("Process not started or sender not available"))
}
} else {
Err(anyhow!("Process not found"))
}
} else {
Err(anyhow!("Process pool not initialized"))
}
}
pub async fn is_process_running(&self) -> bool {
if let Some(process_pool) = PROCESS_POOL.get() {
let pool = process_pool.lock().await;
return pool.contains_key(&self.pid);
}
false
}
pub async fn shutdown(&self, timeout: std::time::Duration) -> Result<()> {
let graceful_shutdown_result = self.graceful_shutdown().await;
if let Err(e) = &graceful_shutdown_result {
debug!("Graceful shutdown attempt failed: {}", e);
return self.kill().await;
}
let start_time = std::time::Instant::now();
while start_time.elapsed() < timeout {
if !self.is_process_running().await {
return Ok(());
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
debug!("Process did not exit gracefully within timeout, forcing termination");
self.kill().await
}
async fn graceful_shutdown(&self) -> Result<()> {
#[cfg(target_os = "windows")]
{
unsafe {
let result = winapi::um::wincon::GenerateConsoleCtrlEvent(0, self.pid);
if result == 0 {
return Err(anyhow!(
"Failed to send Ctrl+C to process {}: {}",
self.pid,
std::io::Error::last_os_error()
));
}
}
return Ok(());
}
#[cfg(target_os = "linux")]
{
unsafe {
let result = libc::kill(self.pid as libc::pid_t, libc::SIGTERM);
if result != 0 {
return Err(anyhow!(
"Failed to send SIGTERM to process {}: {}",
self.pid,
std::io::Error::last_os_error()
));
}
}
return Ok(());
}
#[cfg(not(any(target_os = "windows", target_os = "linux")))]
{
return Err(anyhow!("Graceful shutdown not implemented for this platform"));
}
}
pub async fn kill(&self) -> Result<()> {
#[cfg(target_os = "windows")]
{
unsafe {
let handle = winapi::um::processthreadsapi::OpenProcess(0x00010000, 0, self.pid);
if handle.is_null() {
return Err(anyhow!("Failed to open process {}: {}", self.pid, std::io::Error::last_os_error()));
}
let result = winapi::um::processthreadsapi::TerminateProcess(handle, 0);
let close_result = winapi::um::handleapi::CloseHandle(handle);
if result == 0 {
return Err(anyhow!(
"Failed to terminate process {}: {}",
self.pid,
std::io::Error::last_os_error()
));
}
if close_result == 0 {
warn!(
"Failed to close process handle for process {}: {}",
self.pid,
std::io::Error::last_os_error()
);
}
}
}
#[cfg(target_os = "linux")]
{
unsafe {
let result = libc::kill(self.pid as libc::pid_t, libc::SIGKILL);
if result != 0 {
return Err(anyhow!("Failed to kill process {}: {}", self.pid, std::io::Error::last_os_error()));
}
}
}
Ok(())
}
}
impl AsynchronousInteractiveProcess {
pub fn new(filename: impl Into<String>) -> Self {
Self {
pid: None,
filename: filename.into(),
arguments: Vec::new(),
working_directory: PathBuf::from("./"),
sender: None,
output_broadcaster: None,
_keep_alive_receiver: None,
input_queue: VecDeque::new(),
exit_callback: None,
}
}
pub fn with_arguments(mut self, args: Vec<impl Into<String>>) -> Self {
self.arguments = args.into_iter().map(|arg| arg.into()).collect();
self
}
pub fn with_argument(mut self, arg: impl Into<String>) -> Self {
self.arguments.push(arg.into());
self
}
pub fn with_working_directory(mut self, dir: impl Into<PathBuf>) -> Self {
self.working_directory = dir.into();
self
}
pub fn process_exit_callback<F>(mut self, callback: F) -> Self
where
F: Fn(i32) + Send + Sync + 'static,
{
self.exit_callback = Some(Arc::new(callback));
self
}
pub async fn start(&mut self) -> Result<u32> {
let mut command = Command::new(&self.filename);
command.args(&self.arguments);
let working_dir = if cfg!(windows) {
let path_str = self.working_directory.to_string_lossy();
if path_str.starts_with(r"\\?\") {
PathBuf::from(&path_str[4..]) } else {
self.working_directory.clone()
}
} else {
self.working_directory.clone()
};
debug!("tokio-interactive: filename = {}", self.filename);
debug!("tokio-interactive: arguments = {:?}", self.arguments);
debug!("tokio-interactive: working_directory = {:?}", working_dir);
debug!("tokio-interactive: working_directory exists = {}", working_dir.exists());
debug!("tokio-interactive: working_directory is_dir = {}", working_dir.is_dir());
debug!("tokio-interactive: working_directory is_absolute = {}", working_dir.is_absolute());
if let Ok(current_dir) = std::env::current_dir() {
debug!("tokio-interactive: current working directory before setting = {:?}", current_dir);
}
command.current_dir(working_dir);
command.stdin(std::process::Stdio::piped());
command.stdout(std::process::Stdio::piped());
command.stderr(std::process::Stdio::piped());
debug!("tokio-interactive: Final command = {:?}", command);
let mut child = command.spawn()?;
let pid = child.id().unwrap_or(0);
debug!("tokio-interactive: Process spawned with PID = {}", pid);
let (stdin_sender, mut stdin_receiver) = tokio::sync::mpsc::channel::<String>(100);
let (stdout_broadcaster, _keep_alive_receiver) = broadcast::channel::<String>(100);
if let Some(mut stdin) = child.stdin.take() {
tokio::spawn(async move {
while let Some(input) = stdin_receiver.recv().await {
if let Err(e) = stdin.write_all(input.as_bytes()).await {
error!("Failed to write to process stdin: {}", e);
break;
}
if let Err(e) = stdin.write_all(b"\n").await {
error!("Failed to write newline to process stdin: {}", e);
break;
}
if let Err(e) = stdin.flush().await {
error!("Failed to flush process stdin: {}", e);
break;
}
}
});
}
if let Some(stdout) = child.stdout.take() {
let stdout_broadcaster_clone = stdout_broadcaster.clone();
tokio::spawn(async move {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
while let Ok(bytes_read) = reader.read_line(&mut line).await {
if bytes_read == 0 {
break;
}
if let Err(e) = stdout_broadcaster_clone.send(line.trim_end().to_string()) {
error!("Failed to broadcast stdout message: {}", e);
break;
}
line.clear();
}
});
}
if let Some(stderr) = child.stderr.take() {
let stderr_broadcaster = stdout_broadcaster.clone();
tokio::spawn(async move {
let mut reader = BufReader::new(stderr);
let mut line = String::new();
while let Ok(bytes_read) = reader.read_line(&mut line).await {
if bytes_read == 0 {
break;
}
if let Err(e) = stderr_broadcaster.send(format!("STDERR: {}", line.trim_end())) {
error!("Failed to broadcast stderr message: {}", e);
break;
}
line.clear();
}
});
}
let queue_pid = pid;
tokio::spawn(async move {
loop {
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
let process_pool = match PROCESS_POOL.get() {
Some(pool) => pool,
None => break,
};
let mut pool = process_pool.lock().await;
let process = match pool.get_mut(&queue_pid) {
Some(p) => p,
None => break,
};
while let Some(input) = process.input_queue.pop_front() {
if let Some(sender) = &process.sender {
if let Err(_) = sender.try_send(input.clone()) {
process.input_queue.push_front(input);
break;
}
} else {
process.input_queue.clear();
break;
}
}
drop(pool);
}
});
let exit_callback = self.exit_callback.clone();
tokio::spawn(async move {
match child.wait().await {
Ok(exit_status) => {
let exit_code = exit_status.code().unwrap_or(-1); debug!("Process {} exited with code: {}", pid, exit_code);
if let Some(exit_callback) = exit_callback {
exit_callback(exit_code);
}
}
Err(e) => {
error!("Process {} exited with error: {}", pid, e);
if let Some(exit_callback) = exit_callback {
exit_callback(-1); }
}
}
if let Some(process_pool) = PROCESS_POOL.get() {
let mut pool = process_pool.lock().await;
pool.remove(&pid);
debug!("Process {} has exited and been removed from the pool", pid);
}
});
let process_pool = PROCESS_POOL.get_or_init(|| Arc::new(Mutex::new(HashMap::new())));
let mut pool = process_pool.lock().await;
pool.insert(
pid,
Self {
pid: Some(pid),
filename: self.filename.clone(),
arguments: self.arguments.clone(),
working_directory: self.working_directory.clone(),
sender: Some(stdin_sender),
output_broadcaster: Some(stdout_broadcaster.clone()),
_keep_alive_receiver: Some(_keep_alive_receiver),
input_queue: VecDeque::new(),
exit_callback: self.exit_callback.clone(),
},
);
self.pid = Some(pid);
Ok(pid)
}
pub async fn get_process_by_pid(pid: u32) -> Option<ProcessHandle> {
let process_pool = PROCESS_POOL.get()?;
let pool = process_pool.lock().await;
if let Some(process) = pool.get(&pid) {
if let Some(broadcaster) = &process.output_broadcaster {
let receiver = broadcaster.subscribe();
Some(ProcessHandle { pid, receiver })
} else {
None
}
} else {
None
}
}
pub async fn is_process_running(&self) -> bool {
if let Some(pid) = self.pid {
if let Some(process_pool) = PROCESS_POOL.get() {
let pool = process_pool.lock().await;
return pool.contains_key(&pid);
}
}
false
}
}