use std::{
collections::VecDeque,
io,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use tokio::{
io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
process::{Child, Command},
sync::{mpsc, oneshot, Mutex as TokioMutex},
};
/// Result type used throughout this module; every fallible operation
/// reports failures as a [`std::io::Error`].
pub type Result<T> = io::Result<T>;
/// Turns raw bytes into a `String`, mapping invalid UTF-8 to an
/// `InvalidData` I/O error.
fn decode_utf8(bytes: Vec<u8>) -> Result<String> {
    match String::from_utf8(bytes) {
        Ok(text) => Ok(text),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "Invalid UTF-8 sequence",
        )),
    }
}

/// Byte- and text-oriented I/O over a child process.
///
/// Implementors provide the three byte-level primitives plus the liveness
/// and stdin-closing hooks; the string-based methods are built on top of
/// those primitives.
pub trait SubprocessHandler: Send {
    /// Writes `input` to the process.
    fn write_bytes(&mut self, input: &[u8])
        -> impl std::future::Future<Output = Result<()>> + Send;

    /// Writes the UTF-8 bytes of `input` via [`Self::write_bytes`].
    fn write(&mut self, input: &str) -> impl std::future::Future<Output = Result<()>> + Send {
        async move {
            let payload = input.as_bytes();
            self.write_bytes(payload).await
        }
    }

    /// Writes `input` followed by a single `\n`, as two separate
    /// [`Self::write_bytes`] calls.
    fn write_line(&mut self, input: &str) -> impl std::future::Future<Output = Result<()>> + Send {
        async move {
            self.write_bytes(input.as_bytes()).await?;
            self.write_bytes(b"\n").await
        }
    }

    /// Reads bytes from the process.
    fn read_bytes(&mut self) -> impl std::future::Future<Output = Result<Vec<u8>>> + Send;

    /// Reads bytes from the process up to `delimiter`.
    fn read_bytes_until(
        &mut self,
        delimiter: u8,
    ) -> impl std::future::Future<Output = Result<Vec<u8>>> + Send;

    /// Like [`Self::read_bytes`], but decodes the result as UTF-8.
    ///
    /// Fails with `InvalidData` when the bytes are not valid UTF-8.
    fn read(&mut self) -> impl std::future::Future<Output = Result<String>> + Send {
        async move {
            let raw = self.read_bytes().await?;
            decode_utf8(raw)
        }
    }

    /// Like [`Self::read_bytes_until`], but decodes the result as UTF-8.
    ///
    /// Fails with `InvalidData` when the bytes are not valid UTF-8.
    fn read_until(
        &mut self,
        delimiter: u8,
    ) -> impl std::future::Future<Output = Result<String>> + Send {
        async move {
            let raw = self.read_bytes_until(delimiter).await?;
            decode_utf8(raw)
        }
    }

    /// Reads up to `\n` via [`Self::read_bytes_until`] and decodes the
    /// result as UTF-8.
    fn read_line(&mut self) -> impl std::future::Future<Output = Result<String>> + Send {
        async move {
            let raw = self.read_bytes_until(b'\n').await?;
            decode_utf8(raw)
        }
    }

    /// Reports whether the process is still running.
    fn is_alive(&mut self) -> bool;

    /// Closes the process's standard input.
    fn close_stdin(&mut self);
}
/// A pool of child processes that are handed out one at a time as
/// `PooledProcess` guards and replenished when they die.
pub struct SubprocessPool {
/// Upper bound on concurrently checked-out processes; also the number of
/// processes spawned up front and the capacity of both internal channels.
max_size: usize,
/// Idle processes waiting to be handed out.
processes: TokioMutex<VecDeque<Subprocess>>,
/// Number of slots currently reserved by `acquire` / held by guards.
active_count: Arc<AtomicUsize>,
/// Sending half of the return channel; cloned into every `PooledProcess`
/// so the guard can give its process back when dropped.
return_tx: mpsc::Sender<Subprocess>,
/// Receiving half of the return channel; drained into `processes` by
/// `acquire` and `count`.
return_rx: TokioMutex<mpsc::Receiver<Subprocess>>,
/// Factory producing the `Command` for each newly spawned process.
builder: Arc<dyn Fn() -> Command + Send + Sync>,
/// Queue of parked `acquire` callers: `acquire` pushes a oneshot sender
/// when the pool is saturated, and dropping a `PooledProcess` pops from
/// the receiving half to wake a caller.
waiters: (
mpsc::Sender<oneshot::Sender<()>>,
TokioMutex<mpsc::Receiver<oneshot::Sender<()>>>,
),
}
impl SubprocessPool {
pub async fn new(
builder: impl Fn() -> Command + Send + Sync + 'static,
max_size: usize,
) -> Result<Arc<Self>> {
let (return_tx, return_rx) = mpsc::channel(max_size);
let (waiter_tx, waiter_rx) = mpsc::channel(max_size);
let mut processes = VecDeque::with_capacity(max_size);
let builder = Arc::new(builder);
for _ in 0..max_size {
let process = Subprocess::from_builder(builder.clone())?;
processes.push_back(process);
}
Ok(Arc::new(Self {
builder,
max_size,
processes: TokioMutex::new(processes),
active_count: Arc::new(AtomicUsize::new(0)),
return_tx,
return_rx: TokioMutex::new(return_rx),
waiters: (waiter_tx, TokioMutex::new(waiter_rx)),
}))
}
pub async fn acquire(self: &Arc<Self>) -> Result<PooledProcess> {
loop {
let current = self.active_count.load(Ordering::SeqCst);
if current >= self.max_size {
let (notify_tx, notify_rx) = oneshot::channel();
if self.waiters.0.send(notify_tx).await.is_ok() {
let _ = notify_rx.await;
}
continue;
}
if self
.active_count
.compare_exchange(current, current + 1, Ordering::SeqCst, Ordering::SeqCst)
.is_err()
{
continue;
}
{
let mut rx = self.return_rx.lock().await;
let mut processes = self.processes.lock().await;
while let Ok(process) = rx.try_recv() {
processes.push_back(process);
}
processes.retain_mut(|p| p.is_alive());
while processes.len() < self.max_size {
match self.spawn_process(&mut processes) {
Ok(_) => {}
Err(e) => {
self.active_count.fetch_sub(1, Ordering::SeqCst);
return Err(e);
}
}
}
if let Some(mut process) = processes.pop_front() {
drop(processes);
drop(rx);
if process.is_alive() {
return Ok(PooledProcess {
process: Some(process),
return_tx: self.return_tx.clone(),
active_count: self.active_count.clone(),
pool: self.clone(),
});
}
continue;
}
drop(processes);
drop(rx);
}
self.active_count.fetch_sub(1, Ordering::SeqCst);
}
}
pub async fn count(&self) -> usize {
let mut rx = self.return_rx.lock().await;
let mut processes = self.processes.lock().await;
while let Ok(process) = rx.try_recv() {
processes.push_back(process);
}
processes.retain_mut(|p| p.is_alive());
while processes.len() < self.max_size {
match self.spawn_process(&mut processes) {
Ok(_) => {}
Err(_) => break,
}
}
processes.len()
}
fn spawn_process(&self, processes: &mut VecDeque<Subprocess>) -> Result<()> {
let process = Subprocess::from_builder(self.builder.clone())?;
processes.push_back(process);
Ok(())
}
}
/// Guard around a process checked out of a `SubprocessPool`; dropping it
/// sends the process back to the pool and frees its slot.
pub struct PooledProcess {
/// The checked-out process; an `Option` so `Drop` can move it out.
process: Option<Subprocess>,
/// Channel used on drop to hand the process back to the pool.
return_tx: mpsc::Sender<Subprocess>,
/// Shared counter of occupied slots, decremented on drop.
active_count: Arc<AtomicUsize>,
/// Owning pool, used on drop to reach the waiter queue.
pool: Arc<SubprocessPool>,
}
impl Drop for PooledProcess {
    /// Hands the process back to the pool, frees its slot and wakes one
    /// caller parked in `SubprocessPool::acquire`.
    fn drop(&mut self) {
        let Some(process) = self.process.take() else {
            return;
        };

        // Best effort: if the return channel is full the process is simply
        // dropped; the pool spawns replacements on demand.
        let _ = self.return_tx.try_send(process);

        // Free the slot BEFORE waking anyone. A woken caller immediately
        // re-reads `active_count`; if it still saw the old value it would
        // park again with nobody left to wake it.
        self.active_count.fetch_sub(1, Ordering::SeqCst);

        // Wake a waiter whether or not the hand-back succeeded: the slot is
        // free either way. `try_lock` because `drop` must not block.
        if let Ok(mut waiters) = self.pool.waiters.1.try_lock() {
            while let Ok(waiter) = waiters.try_recv() {
                // A failed send means that caller is no longer waiting; move
                // on so the wake-up is not wasted on it.
                if waiter.send(()).is_ok() {
                    break;
                }
            }
        }
    }
}
impl SubprocessHandler for PooledProcess {
async fn write_bytes(&mut self, input: &[u8]) -> Result<()> {
let Some(process) = self.process.as_mut() else {
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"No process available",
));
};
process.write_bytes(input).await
}
async fn read_bytes(&mut self) -> Result<Vec<u8>> {
let Some(process) = self.process.as_mut() else {
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"No process available",
));
};
process.read_bytes().await
}
async fn read_bytes_until(&mut self, delimiter: u8) -> Result<Vec<u8>> {
let Some(process) = self.process.as_mut() else {
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"No process available",
));
};
process.read_bytes_until(delimiter).await
}
fn is_alive(&mut self) -> bool {
self.process.as_mut().map(|x| x.is_alive()).unwrap_or(false)
}
fn close_stdin(&mut self) {
if let Some(process) = self.process.as_mut() {
process.close_stdin()
}
}
}
/// A spawned child process together with a buffered reader over its stdout.
pub struct Subprocess {
/// Handle to the spawned child; its stdin is written through this handle.
child: Child,
/// Buffered reader over the child's stdout, taken from `child` right
/// after spawning; `None` if no stdout handle was available.
stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
}
impl Subprocess {
    /// Spawns the command produced by `builder`.
    ///
    /// # Errors
    ///
    /// Returns the error from spawning the command.
    pub fn new(builder: impl Fn() -> Command + 'static) -> Result<Self> {
        Self::from_command(builder())
    }

    /// Spawns the command produced by a shared `builder`.
    ///
    /// # Errors
    ///
    /// Returns the error from spawning the command.
    pub fn from_builder(builder: Arc<dyn Fn() -> Command>) -> Result<Self> {
        Self::from_command((*builder)())
    }

    /// Configures the standard streams of `command` and spawns it.
    fn from_command(mut command: Command) -> Result<Self> {
        command.stdin(std::process::Stdio::piped());
        command.stdout(std::process::Stdio::piped());
        // This type never reads the child's stderr and offers no accessor for
        // it. A piped-but-undrained stderr blocks the child as soon as the
        // pipe buffer fills, so the stream is discarded instead.
        command.stderr(std::process::Stdio::null());
        let mut child = command.spawn()?;
        let stdout_reader = child.stdout.take().map(BufReader::new);
        Ok(Self {
            child,
            stdout_reader,
        })
    }

    /// Returns `true` while the child has not exited.
    ///
    /// A failure to query the child's status is treated as "not alive".
    pub fn is_alive(&mut self) -> bool {
        matches!(self.child.try_wait(), Ok(None))
    }

    /// Writes `input` to the child's stdin and flushes it.
    ///
    /// # Errors
    ///
    /// * `BrokenPipe` if stdin has been closed.
    /// * Any error from the write or flush.
    /// * `BrokenPipe` if the child is found to have exited right after the
    ///   write completed.
    pub async fn write_bytes(&mut self, input: &[u8]) -> Result<()> {
        let Some(stdin) = self.child.stdin.as_mut() else {
            return Err(Self::broken_pipe("stdin is closed"));
        };
        stdin.write_all(input).await?;
        stdin.flush().await?;
        if self.is_alive() {
            Ok(())
        } else {
            Err(Self::broken_pipe("Process exited during write"))
        }
    }

    /// Reads the child's stdout until end of stream.
    ///
    /// # Errors
    ///
    /// * `BrokenPipe` if there is no stdout reader.
    /// * Any error from the read.
    /// * `BrokenPipe` if nothing was read and the child is no longer alive.
    pub async fn read_bytes(&mut self) -> Result<Vec<u8>> {
        let Some(stdout) = self.stdout_reader.as_mut() else {
            return Err(Self::broken_pipe("stdout is closed"));
        };
        let mut buf = Vec::new();
        let bytes_read = stdout.read_to_end(&mut buf).await?;
        self.finish_read(bytes_read, buf)
    }

    /// Reads the child's stdout up to and including `delimiter`, or to end
    /// of stream if the delimiter does not occur.
    ///
    /// # Errors
    ///
    /// * `BrokenPipe` if there is no stdout reader.
    /// * Any error from the read.
    /// * `BrokenPipe` if nothing was read and the child is no longer alive.
    pub async fn read_bytes_until(&mut self, delimiter: u8) -> Result<Vec<u8>> {
        let Some(stdout) = self.stdout_reader.as_mut() else {
            return Err(Self::broken_pipe("stdout is closed"));
        };
        let mut buf = Vec::new();
        let bytes_read = stdout.read_until(delimiter, &mut buf).await?;
        self.finish_read(bytes_read, buf)
    }

    /// Closes the child's stdin by dropping the handle.
    pub fn close_stdin(&mut self) {
        self.child.stdin.take();
    }

    /// Turns an empty read from a dead child into an error; otherwise
    /// returns the buffer unchanged.
    fn finish_read(&mut self, bytes_read: usize, buf: Vec<u8>) -> Result<Vec<u8>> {
        if bytes_read == 0 && !self.is_alive() {
            return Err(Self::broken_pipe("Process exited during read"));
        }
        Ok(buf)
    }

    /// Builds a `BrokenPipe` error carrying `message`.
    fn broken_pipe(message: &'static str) -> io::Error {
        io::Error::new(io::ErrorKind::BrokenPipe, message)
    }
}
impl SubprocessHandler for Subprocess {
async fn write(&mut self, input: &str) -> Result<()> {
self.write_bytes(input.as_bytes()).await
}
async fn write_line(&mut self, input: &str) -> Result<()> {
self.write_bytes(input.as_bytes()).await?;
self.write_bytes(b"\n").await?;
Ok(())
}
async fn read(&mut self) -> Result<String> {
let bytes = self.read_bytes().await?;
String::from_utf8(bytes)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid UTF-8 sequence"))
}
async fn read_line(&mut self) -> Result<String> {
let bytes = self.read_bytes_until(b'\n').await?;
String::from_utf8(bytes)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid UTF-8 sequence"))
}
async fn write_bytes(&mut self, input: &[u8]) -> Result<()> {
self.write_bytes(input).await
}
async fn read_bytes(&mut self) -> Result<Vec<u8>> {
self.read_bytes().await
}
async fn read_bytes_until(&mut self, delimiter: u8) -> Result<Vec<u8>> {
self.read_bytes_until(delimiter).await
}
fn is_alive(&mut self) -> bool {
self.is_alive()
}
fn close_stdin(&mut self) {
self.close_stdin()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Command for the `echo` example binary under
    /// `<manifest dir>/target/debug/examples/echo`.
    fn get_echo_binary() -> Command {
        let path: PathBuf = [
            env!("CARGO_MANIFEST_DIR"),
            "target",
            "debug",
            "examples",
            "echo",
        ]
        .iter()
        .collect();
        Command::new(path)
    }

    /// Polls `pool.count()` every 500 ms, at most 20 times, stopping early
    /// once the pool reports two processes.
    async fn wait_for_two_processes(pool: &SubprocessPool) {
        for attempt in 1..=20 {
            tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
            let count = pool.count().await;
            println!("After {}ms: pool has {} processes", attempt * 500, count);
            if count == 2 {
                break;
            }
        }
    }

    /// A fresh pool reports as many processes as it was sized for.
    #[tokio::test]
    async fn test_pool_initialization() {
        let pool = SubprocessPool::new(get_echo_binary, 2).await.unwrap();
        let count = pool.count().await;
        assert_eq!(count, 2);
    }

    /// Two request/response round trips on the same process.
    #[tokio::test]
    async fn test_echo_interactive() {
        let pool = SubprocessPool::new(get_echo_binary, 1).await.unwrap();
        let mut process = pool.acquire().await.unwrap();

        process.write_line("hello").await.unwrap();
        assert_eq!(process.read_line().await.unwrap(), ">>hello\n");

        process.write_line("world").await.unwrap();
        assert_eq!(process.read_line().await.unwrap(), ">>world\n");
    }

    /// `write_line` appends the newline the echo binary waits for.
    #[tokio::test]
    async fn test_write_line() {
        let pool = SubprocessPool::new(get_echo_binary, 1).await.unwrap();
        let mut process = pool.acquire().await.unwrap();

        process.write_line("test").await.unwrap();
        let reply = process.read_line().await.unwrap();
        assert_eq!(reply, ">>test\n");
    }

    /// After the `/io` command, a later write fails.
    #[tokio::test]
    async fn test_echo_io_error() {
        let pool = SubprocessPool::new(get_echo_binary, 1).await.unwrap();
        let mut process = pool.acquire().await.unwrap();

        process.write_line("/io").await.unwrap();
        let reply = process.read_line().await.unwrap();
        assert_eq!(reply, "IO error incoming\n");

        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        let outcome = process.write_line("test").await;
        assert!(outcome.is_err());
    }

    /// After the `/exit 42` command, a later write fails.
    #[tokio::test]
    async fn test_echo_exit_code() {
        let pool = SubprocessPool::new(get_echo_binary, 1).await.unwrap();
        let mut process = pool.acquire().await.unwrap();

        process.write_line("/exit 42").await.unwrap();
        let reply = process.read_line().await.unwrap();
        assert_eq!(reply, "Exiting with code 42\n");

        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        let outcome = process.write_line("test").await;
        assert!(outcome.is_err());
    }

    /// After the `/invalid-utf8` command, `read_line` fails.
    #[tokio::test]
    async fn test_echo_invalid_utf8() {
        let pool = SubprocessPool::new(get_echo_binary, 1).await.unwrap();
        let mut process = pool.acquire().await.unwrap();

        process.write_line("/invalid-utf8").await.unwrap();
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let outcome = process.read_line().await;
        assert!(outcome.is_err());
    }

    /// Both processes are told to exit; the pool gets back to full size and
    /// still serves a working process.
    #[tokio::test]
    async fn test_pool_exhaustion_and_recreation() {
        let pool = SubprocessPool::new(get_echo_binary, 2).await.unwrap();
        assert_eq!(pool.count().await, 2, "Pool should start with max size");

        let mut first = pool.acquire().await.unwrap();
        let mut second = pool.acquire().await.unwrap();
        first.write_line("/exit 1").await.unwrap();
        second.write_line("/exit 1").await.unwrap();
        drop(first);
        drop(second);

        wait_for_two_processes(&pool).await;
        assert_eq!(pool.count().await, 2, "Pool should be back at max size");

        let mut process = pool.acquire().await.unwrap();
        process.write_line("test").await.unwrap();
        let reply = process.read_line().await.unwrap();
        assert_eq!(reply, ">>test\n");
    }

    /// Same scenario as above with step-by-step progress output.
    #[tokio::test]
    async fn test_pool_all_processes_dead() {
        println!("Starting test_pool_all_processes_dead");
        let pool = SubprocessPool::new(get_echo_binary, 2).await.unwrap();
        println!("Pool created with 2 processes");

        let mut first = pool.acquire().await.unwrap();
        let mut second = pool.acquire().await.unwrap();
        println!("Acquired both processes");

        first.write_line("/exit 1").await.unwrap();
        second.write_line("/exit 1").await.unwrap();
        println!("Sent exit commands to both processes");

        drop(first);
        drop(second);
        println!("Dropped processes");

        wait_for_two_processes(&pool).await;

        println!("Checking pool size");
        assert_eq!(pool.count().await, 2, "Pool should recreate all processes");

        println!("Testing final process");
        let mut process = pool.acquire().await.unwrap();
        println!("Got final process");
        process.write_line("final test").await.unwrap();
        println!("Wrote to final process");
        let reply = process.read_line().await.unwrap();
        println!("Got final response: {}", reply);
        assert_eq!(reply, ">>final test\n");
        println!("Test complete");
    }

    /// Ten tasks share a pool of five; never more than five hold a process
    /// at the same time.
    #[tokio::test]
    async fn test_pool_multiple_acquires() {
        println!("Starting test");
        let pool = SubprocessPool::new(get_echo_binary, 5).await.unwrap();
        let in_flight = Arc::new(AtomicUsize::new(0));
        println!("Pool created");

        let mut handles = Vec::new();
        for task_id in 0..10 {
            let pool = pool.clone();
            let in_flight = in_flight.clone();
            let handle = tokio::spawn(async move {
                println!("Task {} starting", task_id);
                println!("Task {} acquiring process", task_id);
                let mut process = pool.acquire().await.unwrap();
                let prev_count = in_flight.fetch_add(1, Ordering::SeqCst);
                println!("Task {} got count {}", task_id, prev_count);
                assert!(
                    prev_count < 5,
                    "Should never have more than 5 concurrent tasks"
                );
                println!("Task {} got process", task_id);
                println!("Task {} writing", task_id);
                process
                    .write_line(&format!("Task {}", task_id))
                    .await
                    .unwrap();
                println!("Task {} reading", task_id);
                let reply = process.read_line().await.unwrap();
                println!("Task {} got response: {}", task_id, reply);
                assert_eq!(reply, format!(">>Task {}\n", task_id));
                println!("Task {} releasing process", task_id);
                in_flight.fetch_sub(1, Ordering::SeqCst);
                drop(process);
                println!("Task {} done", task_id);
            });
            handles.push(handle);
        }

        println!("Waiting for tasks to complete");
        for handle in handles {
            handle.await.unwrap();
        }
        println!("All tasks complete");
        assert_eq!(
            in_flight.load(Ordering::SeqCst),
            0,
            "All tasks should be complete"
        );
    }

    /// A command that cannot be spawned makes pool construction fail.
    #[tokio::test]
    async fn test_pool_spawn_failure() {
        let outcome = SubprocessPool::new(|| Command::new("\0"), 2).await;
        assert!(outcome.is_err());
    }

    /// Writing to and reading from the `false` command both fail after a
    /// short delay.
    #[tokio::test]
    async fn test_pool_io_errors() {
        let pool = SubprocessPool::new(|| Command::new("false"), 1)
            .await
            .unwrap();
        let mut process = pool.acquire().await.unwrap();
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let write_outcome = process.write_line("test").await;
        assert!(write_outcome.is_err());

        let read_outcome = process.read_line().await;
        assert!(read_outcome.is_err());
    }

    /// One hundred tasks compete for four processes; afterwards the pool
    /// still serves a working process.
    #[tokio::test]
    async fn test_pool_high_contention() {
        let pool = SubprocessPool::new(get_echo_binary, 4).await.unwrap();

        let mut tasks = Vec::new();
        for task_id in 0..100 {
            let pool = pool.clone();
            tasks.push(tokio::spawn(async move {
                let mut process = pool.acquire().await.unwrap();
                process
                    .write_line(&format!("task_{}", task_id))
                    .await
                    .unwrap();
                let reply = process.read_line().await.unwrap();
                assert_eq!(reply.trim(), format!(">>task_{}", task_id));
                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            }));
        }
        for task in tasks {
            task.await.unwrap();
        }

        let mut process = pool.acquire().await.unwrap();
        process.write_line("final_test").await.unwrap();
        let reply = process.read_line().await.unwrap();
        assert_eq!(reply.trim(), ">>final_test");
    }

    /// Exercises a `Subprocess` without a pool: line and raw writes,
    /// delimiter reads, and exit after stdin is closed.
    #[tokio::test]
    async fn test_subprocess_direct() {
        let mut process = Subprocess::new(get_echo_binary).unwrap();

        process.write_line("hello world").await.unwrap();
        let reply = process.read_line().await.unwrap();
        assert_eq!(reply.trim(), ">>hello world");

        process.write_line("line 1").await.unwrap();
        process.write_line("line 2").await.unwrap();
        let first_reply = process.read_line().await.unwrap();
        let second_reply = process.read_line().await.unwrap();
        assert_eq!(first_reply.trim(), ">>line 1");
        assert_eq!(second_reply.trim(), ">>line 2");

        process.write_bytes(b"raw bytes\n").await.unwrap();
        let reply = process.read_line().await.unwrap();
        assert_eq!(reply.trim(), ">>raw bytes");

        process.write_bytes(b"part1:part2\n").await.unwrap();
        let head = process.read_bytes_until(b':').await.unwrap();
        assert_eq!(&head, b">>part1:");
        let tail = process.read_line().await.unwrap();
        assert_eq!(tail.trim(), "part2");

        assert!(process.is_alive());
        process.close_stdin();
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        assert!(!process.is_alive());
    }
}