#![doc = include_str!("../README.md")]
use anyhow::{Result, anyhow};
use log::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::path::PathBuf;
use std::sync::{Arc, OnceLock};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::sync::mpsc::{Receiver, Sender};
/// Global registry of every process started through [`AsynchronousInteractiveProcess::start`],
/// keyed by OS process id.
///
/// Lazily initialised by the first `start` call (`get_or_init`). A background task spawned by
/// `start` removes an entry after awaiting the child's exit. All access goes through the
/// `tokio::sync::Mutex`.
static PROCESS_POOL: OnceLock<Arc<Mutex<HashMap<u32, AsynchronousInteractiveProcess>>>> = OnceLock::new();
/// Lightweight handle to a process started by [`AsynchronousInteractiveProcess::start`].
///
/// The handle stores only the OS process id. Handles are obtained from
/// `AsynchronousInteractiveProcess::get_process_by_pid`. Operations backed by the global
/// process pool report errors (or `false`) once the process has exited and been removed.
/// Because the handle is a plain pid wrapper, it is `Copy` and can be compared and used as a
/// map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ProcessHandle {
    pid: u32,
}

impl ProcessHandle {
    /// Returns the OS process id this handle refers to.
    #[must_use]
    pub fn pid(&self) -> u32 {
        self.pid
    }
}
/// Description of an interactive child process: what to run, with which arguments and in
/// which directory, plus (once started) the channels used to talk to it.
///
/// Only the configuration fields are (de)serialized. The runtime channels and the pending
/// input queue are skipped by serde.
#[derive(Debug, Serialize, Deserialize)]
pub struct AsynchronousInteractiveProcess {
    /// OS process id; `None` until `start` has spawned the process.
    pub pid: Option<u32>,
    /// Program to execute.
    pub filename: String,
    /// Command-line arguments passed to the program.
    pub arguments: Vec<String>,
    /// Directory the child is started in.
    pub working_directory: PathBuf,
    /// Sending half of the child's stdin channel; populated only on the copy that `start`
    /// stores in the process pool.
    #[serde(skip)]
    sender: Option<Sender<String>>,
    /// Receiving half of the child's output channel; populated only on the pool copy.
    #[serde(skip)]
    receiver: Option<Receiver<String>>,
    /// Inputs that could not be sent because the stdin channel was full, kept in FIFO order.
    #[serde(skip)]
    input_queue: VecDeque<String>,
}

impl Default for AsynchronousInteractiveProcess {
    /// Mirrors `AsynchronousInteractiveProcess::new("")`. In particular, the working directory
    /// is `./` rather than an empty path, which `Command::current_dir` would reject at spawn
    /// time.
    fn default() -> Self {
        Self {
            pid: None,
            filename: String::new(),
            arguments: Vec::new(),
            working_directory: PathBuf::from("./"),
            sender: None,
            receiver: None,
            input_queue: VecDeque::new(),
        }
    }
}
impl ProcessHandle {
const DEFAULT_TIMEOUT_MS: u64 = 100;
async fn try_receive_message(
pid: u32,
pool: &mut tokio::sync::MutexGuard<'_, HashMap<u32, AsynchronousInteractiveProcess>>,
) -> Result<Option<String>> {
if let Some(process) = pool.get_mut(&pid) {
if let Some(receiver) = &mut process.receiver {
match receiver.try_recv() {
Ok(msg) => return Ok(Some(msg)),
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {}
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
return Err(anyhow!("Channel disconnected for process {}", pid));
}
}
} else {
return Err(anyhow!("Receiver not available for process {}", pid));
}
} else {
return Err(anyhow!("Process {} not found", pid));
}
Ok(None)
}
pub async fn receive_output(&self) -> Result<Option<String>> {
self.receive_output_with_timeout(std::time::Duration::from_millis(Self::DEFAULT_TIMEOUT_MS)).await
}
pub async fn receive_output_with_timeout(&self, timeout: std::time::Duration) -> Result<Option<String>> {
const RETRY_DELAY_MS: u64 = 10;
if let Some(process_pool) = PROCESS_POOL.get() {
{
let mut pool = process_pool.lock().await;
if let Some(msg) = Self::try_receive_message(self.pid, &mut pool).await? {
return Ok(Some(msg));
}
}
let start_time = std::time::Instant::now();
while start_time.elapsed() < timeout {
tokio::time::sleep(tokio::time::Duration::from_millis(RETRY_DELAY_MS)).await;
let mut pool = process_pool.lock().await;
if let Some(msg) = Self::try_receive_message(self.pid, &mut pool).await? {
return Ok(Some(msg));
}
}
return Ok(None);
}
Err(anyhow!("Process pool not initialized"))
}
pub async fn send_input(&self, input: impl Into<String>) -> Result<()> {
let input_str = input.into();
if let Some(process_pool) = PROCESS_POOL.get() {
let mut pool = process_pool.lock().await;
if let Some(process) = pool.get_mut(&self.pid) {
if let Some(sender) = &process.sender {
match sender.try_send(input_str.clone()) {
Ok(_) => Ok(()),
Err(e) => match e {
tokio::sync::mpsc::error::TrySendError::Full(_) => {
process.input_queue.push_back(input_str);
Ok(())
}
tokio::sync::mpsc::error::TrySendError::Closed(_) => {
Err(anyhow!("Failed to send input: channel closed"))
}
},
}
} else {
Err(anyhow!("Process not started or sender not available"))
}
} else {
Err(anyhow!("Process not found"))
}
} else {
Err(anyhow!("Process pool not initialized"))
}
}
pub async fn is_process_running(&self) -> bool {
if let Some(process_pool) = PROCESS_POOL.get() {
let pool = process_pool.lock().await;
return pool.contains_key(&self.pid);
}
false
}
pub async fn shutdown(&self, timeout: std::time::Duration) -> Result<()> {
let graceful_shutdown_result = self.graceful_shutdown().await;
if let Err(e) = &graceful_shutdown_result {
debug!("Graceful shutdown attempt failed: {}", e);
return self.kill().await;
}
let start_time = std::time::Instant::now();
while start_time.elapsed() < timeout {
if !self.is_process_running().await {
return Ok(());
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
debug!("Process did not exit gracefully within timeout, forcing termination");
self.kill().await
}
async fn graceful_shutdown(&self) -> Result<()> {
#[cfg(target_os = "windows")]
{
unsafe {
let result = winapi::um::wincon::GenerateConsoleCtrlEvent(0, self.pid);
if result == 0 {
return Err(anyhow!(
"Failed to send Ctrl+C to process {}: {}",
self.pid,
std::io::Error::last_os_error()
));
}
}
return Ok(());
}
#[cfg(target_os = "linux")]
{
unsafe {
let result = libc::kill(self.pid as libc::pid_t, libc::SIGTERM);
if result != 0 {
return Err(anyhow!(
"Failed to send SIGTERM to process {}: {}",
self.pid,
std::io::Error::last_os_error()
));
}
}
return Ok(());
}
#[cfg(not(any(target_os = "windows", target_os = "linux")))]
{
return Err(anyhow!("Graceful shutdown not implemented for this platform"));
}
}
pub async fn kill(&self) -> Result<()> {
#[cfg(target_os = "windows")]
{
unsafe {
let handle = winapi::um::processthreadsapi::OpenProcess(0x00010000, 0, self.pid);
if handle.is_null() {
return Err(anyhow!("Failed to open process {}: {}", self.pid, std::io::Error::last_os_error()));
}
let result = winapi::um::processthreadsapi::TerminateProcess(handle, 0);
let close_result = winapi::um::handleapi::CloseHandle(handle);
if result == 0 {
return Err(anyhow!(
"Failed to terminate process {}: {}",
self.pid,
std::io::Error::last_os_error()
));
}
if close_result == 0 {
warn!(
"Failed to close process handle for process {}: {}",
self.pid,
std::io::Error::last_os_error()
);
}
}
}
#[cfg(target_os = "linux")]
{
unsafe {
let result = libc::kill(self.pid as libc::pid_t, libc::SIGKILL);
if result != 0 {
return Err(anyhow!("Failed to kill process {}: {}", self.pid, std::io::Error::last_os_error()));
}
}
}
Ok(())
}
}
impl AsynchronousInteractiveProcess {
pub fn new(filename: impl Into<String>) -> Self {
Self {
pid: None,
filename: filename.into(),
arguments: Vec::new(),
working_directory: PathBuf::from("./"),
sender: None,
receiver: None,
input_queue: VecDeque::new(),
}
}
pub fn with_arguments(mut self, args: Vec<impl Into<String>>) -> Self {
self.arguments = args.into_iter().map(|arg| arg.into()).collect();
self
}
pub fn with_argument(mut self, arg: impl Into<String>) -> Self {
self.arguments.push(arg.into());
self
}
pub fn with_working_directory(mut self, dir: impl Into<PathBuf>) -> Self {
self.working_directory = dir.into();
self
}
pub async fn start(&mut self) -> Result<u32> {
let mut command = Command::new(&self.filename);
command.args(&self.arguments);
command.current_dir(&self.working_directory);
command.stdin(std::process::Stdio::piped());
command.stdout(std::process::Stdio::piped());
command.stderr(std::process::Stdio::piped());
let mut child = command.spawn()?;
let pid = child.id().unwrap_or(0);
let (stdin_sender, mut stdin_receiver) = tokio::sync::mpsc::channel::<String>(100);
let (stdout_sender, stdout_receiver) = tokio::sync::mpsc::channel::<String>(100);
if let Some(mut stdin) = child.stdin.take() {
tokio::spawn(async move {
while let Some(input) = stdin_receiver.recv().await {
if let Err(e) = stdin.write_all(input.as_bytes()).await {
error!("Failed to write to process stdin: {}", e);
break;
}
if let Err(e) = stdin.write_all(b"\n").await {
error!("Failed to write newline to process stdin: {}", e);
break;
}
if let Err(e) = stdin.flush().await {
error!("Failed to flush process stdin: {}", e);
break;
}
}
});
}
if let Some(stdout) = child.stdout.take() {
let stdout_sender_clone = stdout_sender.clone();
tokio::spawn(async move {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
while let Ok(bytes_read) = reader.read_line(&mut line).await {
if bytes_read == 0 {
break;
}
if let Err(e) = stdout_sender_clone.send(line.trim_end().to_string()).await {
error!("Failed to send stdout message: {}", e);
break;
}
line.clear();
}
});
}
if let Some(stderr) = child.stderr.take() {
let stderr_sender = stdout_sender.clone();
tokio::spawn(async move {
let mut reader = BufReader::new(stderr);
let mut line = String::new();
while let Ok(bytes_read) = reader.read_line(&mut line).await {
if bytes_read == 0 {
break;
}
if let Err(e) = stderr_sender.send(format!("STDERR: {}", line.trim_end())).await {
error!("Failed to send stderr message: {}", e);
break;
}
line.clear();
}
});
}
let queue_pid = pid;
tokio::spawn(async move {
loop {
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
let process_pool = match PROCESS_POOL.get() {
Some(pool) => pool,
None => break,
};
let mut pool = process_pool.lock().await;
let process = match pool.get_mut(&queue_pid) {
Some(p) => p,
None => break,
};
while let Some(input) = process.input_queue.pop_front() {
if let Some(sender) = &process.sender {
if let Err(_) = sender.try_send(input.clone()) {
process.input_queue.push_front(input);
break;
}
} else {
process.input_queue.clear();
break;
}
}
drop(pool);
}
});
tokio::spawn(async move {
if let Err(e) = child.wait().await {
error!("Process {} exited with error: {}", pid, e);
} else {
debug!("Process {} exited successfully", pid);
}
if let Some(process_pool) = PROCESS_POOL.get() {
let mut pool = process_pool.lock().await;
pool.remove(&pid);
debug!("Process {} has exited and been removed from the pool", pid);
}
});
let process_pool = PROCESS_POOL.get_or_init(|| Arc::new(Mutex::new(HashMap::new())));
let mut pool = process_pool.lock().await;
pool.insert(
pid,
Self {
pid: Some(pid),
filename: self.filename.clone(),
arguments: self.arguments.clone(),
working_directory: self.working_directory.clone(),
sender: Some(stdin_sender),
receiver: Some(stdout_receiver),
input_queue: VecDeque::new(),
},
);
self.pid = Some(pid);
Ok(pid)
}
pub async fn get_process_by_pid(pid: u32) -> Option<ProcessHandle> {
let process_pool = PROCESS_POOL.get()?;
let pool = process_pool.lock().await;
if pool.contains_key(&pid) { Some(ProcessHandle { pid }) } else { None }
}
pub async fn is_process_running(&self) -> bool {
if let Some(pid) = self.pid {
if let Some(process_pool) = PROCESS_POOL.get() {
let pool = process_pool.lock().await;
return pool.contains_key(&pid);
}
}
false
}
}