use std::ffi::OsStr;
use std::fmt;
use std::io;
use std::process;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{mpsc, Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;
use ipc_channel::ipc;
use serde::{de::DeserializeOwned, Serialize};
use crate::core::MarshalledCall;
use crate::error::SpawnError;
use crate::proc::{Builder, JoinHandle, JoinHandleInner, ProcCommon, ProcessHandleState};
use crate::serde::with_ipc_mode;
/// Boxed one-shot closure that waits for a job's outcome and returns a `bool`.
/// The monitor thread treats a `false` return as a reason to restart its worker.
type WaitFunc = Box<dyn FnOnce() -> bool + Send>;
/// Boxed callback used to hand a `SpawnError` for a job to whoever is waiting on it.
type NotifyErrorFunc = Box<dyn FnMut(SpawnError) + Send>;
/// State shared between a pooled join handle and the monitor thread that
/// runs the corresponding job.
#[derive(Debug)]
pub struct PooledHandleState {
/// Set to `true` by `kill`; the monitor thread checks it before dispatching the job.
pub cancelled: AtomicBool,
/// Handle state of the process currently assigned to the job. Starts as `None`,
/// is filled in by the monitor thread before dispatch and reset to `None` afterwards.
pub process_handle_state: Mutex<Option<Arc<ProcessHandleState>>>,
}
impl PooledHandleState {
pub fn kill(&self) {
self.cancelled.store(true, Ordering::SeqCst);
if let Some(ref process_handle_state) = *self.process_handle_state.lock().unwrap() {
process_handle_state.kill();
}
}
}
/// Join handle for a job that was submitted to a `Pool`.
pub struct PooledHandle<T> {
/// Receives the job's result (or the error that replaced it).
waiter_rx: mpsc::Receiver<Result<T, SpawnError>>,
/// Cancellation flag and process state shared with the monitor thread.
shared: Arc<PooledHandleState>,
}
impl<T> PooledHandle<T> {
pub fn process_handle_state(&self) -> Option<Arc<ProcessHandleState>> {
self.shared.process_handle_state.lock().unwrap().clone()
}
pub fn kill(&mut self) -> Result<(), SpawnError> {
self.shared.kill();
Ok(())
}
}
impl<T: Serialize + DeserializeOwned> PooledHandle<T> {
    /// Blocks until the job's outcome arrives and returns it.
    ///
    /// # Errors
    ///
    /// Returns the job's own error if one was delivered, or
    /// `SpawnError::new_remote_close()` if the sending side went away without
    /// delivering anything.
    pub fn join(&mut self) -> Result<T, SpawnError> {
        self.waiter_rx
            .recv()
            .unwrap_or_else(|_| Err(SpawnError::new_remote_close()))
    }

    /// Like `join`, but waits at most `timeout`.
    ///
    /// # Errors
    ///
    /// On timeout the job is killed (best effort) and
    /// `SpawnError::new_timeout()` is returned. A disconnected channel yields
    /// `SpawnError::new_remote_close()`; an error delivered by the job is
    /// passed through unchanged.
    pub fn join_timeout(&mut self, timeout: Duration) -> Result<T, SpawnError> {
        let received = self.waiter_rx.recv_timeout(timeout);
        match received {
            Ok(outcome) => outcome,
            Err(mpsc::RecvTimeoutError::Disconnected) => Err(SpawnError::new_remote_close()),
            Err(mpsc::RecvTimeoutError::Timeout) => {
                // The result of `kill` is deliberately ignored; the timeout
                // error is what the caller needs to see.
                self.kill().ok();
                Err(SpawnError::new_timeout())
            }
        }
    }
}
/// Sending half of the pool's job queue. Each queued item carries the
/// marshalled call, the state shared with the job's handle, the closure that
/// waits for the result, and the callback used to report an error.
type PoolSender = mpsc::Sender<(
MarshalledCall,
Arc<PooledHandleState>,
WaitFunc,
NotifyErrorFunc,
)>;
/// A pool of worker processes that jobs can be submitted to.
///
/// Dropping the pool calls `kill`.
pub struct Pool {
/// Sending half of the job queue, wrapped in a mutex so `spawn` can take `&self`.
sender: Mutex<PoolSender>,
/// State shared with the monitor threads.
shared: Arc<PoolShared>,
}
impl fmt::Debug for Pool {
    /// Formats the pool as `Pool { size, active_count, queued_count }` using
    /// the values reported by the corresponding getters.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("Pool");
        out.field("size", &self.size());
        out.field("active_count", &self.active_count());
        out.field("queued_count", &self.queued_count());
        out.finish()
    }
}
impl Pool {
    /// Creates a pool with `size` workers and default settings.
    ///
    /// Equivalent to `Pool::builder(size).build()`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `PoolBuilder::build`.
    pub fn new(size: usize) -> Result<Pool, SpawnError> {
        Pool::builder(size).build()
    }

    /// Returns a `PoolBuilder` for a pool with `size` workers.
    pub fn builder(size: usize) -> PoolBuilder {
        PoolBuilder::new(size)
    }

    /// Returns the number of workers in the pool.
    ///
    /// # Panics
    ///
    /// Panics if the monitor list mutex is poisoned.
    pub fn size(&self) -> usize {
        self.shared.monitors.lock().unwrap().len()
    }

    /// Returns the number of jobs that were queued but not yet picked up by
    /// a monitor thread.
    pub fn queued_count(&self) -> usize {
        // Same ordering as every other access to this counter.
        self.shared.queued_count.load(Ordering::SeqCst)
    }

    /// Returns the number of jobs currently being handled by monitor threads.
    pub fn active_count(&self) -> usize {
        self.shared.active_count.load(Ordering::SeqCst)
    }

    /// Queues `func(args)` for execution on one of the pool's workers and
    /// returns a handle for its result.
    ///
    /// If the IPC channels for the call cannot be created, nothing is queued
    /// and the returned handle carries that error.
    ///
    /// # Panics
    ///
    /// Panics if the pool is dead, if the sender mutex is poisoned, or if
    /// sending `args` over the IPC channel fails.
    pub fn spawn<
        A: Serialize + DeserializeOwned,
        R: Serialize + DeserializeOwned + Send + 'static,
    >(
        &self,
        args: A,
        func: fn(A) -> R,
    ) -> JoinHandle<R> {
        self.assert_alive();

        // Channel creation is OS-level I/O. Report a failure through the
        // handle instead of panicking; no counter has been touched yet, so
        // there is nothing to roll back.
        let (args_tx, args_rx) = match ipc::channel() {
            Ok(pair) => pair,
            Err(err) => {
                return JoinHandle {
                    inner: Err(err.into()),
                }
            }
        };
        let (return_tx, return_rx) = match ipc::channel() {
            Ok(pair) => pair,
            Err(err) => {
                return JoinHandle {
                    inner: Err(err.into()),
                }
            }
        };

        let call = MarshalledCall::marshal::<A, R>(func, args_rx, return_tx);
        // Zero-capacity channel: a send only completes once the handle
        // receives (or fails once the handle is dropped).
        let (waiter_tx, waiter_rx) = mpsc::sync_channel(0);
        let error_waiter_tx = waiter_tx.clone();
        self.shared.queued_count.fetch_add(1, Ordering::SeqCst);
        let shared = Arc::new(PooledHandleState {
            cancelled: AtomicBool::new(false),
            process_handle_state: Mutex::new(None),
        });

        // The receiving end lives in `self.shared`, which this pool keeps
        // alive, so the send result is ignored.
        self.sender
            .lock()
            .expect("pool sender poisoned")
            .send((
                call,
                shared.clone(),
                // Waits for the worker's return value and forwards it to the
                // handle. `false` means the result could not be obtained or
                // delivered.
                Box::new(move || {
                    with_ipc_mode(|| {
                        if let Ok(rv) = return_rx.recv() {
                            waiter_tx.send(rv.map_err(Into::into)).is_ok()
                        } else {
                            false
                        }
                    })
                }),
                // Delivers an error to the handle instead of a result.
                Box::new(move |error| {
                    error_waiter_tx.send(Err(error)).ok();
                }),
            ))
            .ok();

        args_tx.send(args).unwrap();
        JoinHandle {
            inner: Ok(JoinHandleInner::Pooled(PooledHandle { waiter_rx, shared })),
        }
    }

    /// Blocks until the pool has neither queued nor active jobs.
    ///
    /// # Panics
    ///
    /// Panics if the pool is dead or if the internal mutex is poisoned.
    pub fn join(&self) {
        self.assert_alive();

        // Fast path: nothing to wait for, no mutex needed.
        if !self.shared.has_work() {
            return;
        }

        let generation = self.shared.join_generation.load(Ordering::SeqCst);
        let mut lock = self.shared.empty_trigger.lock().unwrap();
        // Wait while there is still work and no other joiner has already
        // advanced the generation we started with.
        while generation == self.shared.join_generation.load(Ordering::Relaxed)
            && self.shared.has_work()
        {
            lock = self.shared.empty_condvar.wait(lock).unwrap();
        }

        // Advance the generation if this is the first thread to leave the
        // loop. Other joiners that started with the same generation then
        // leave as well, even if new work has arrived in the meantime.
        self.shared
            .join_generation
            .compare_exchange(
                generation,
                generation.wrapping_add(1),
                Ordering::SeqCst,
                Ordering::SeqCst,
            )
            .ok();
    }

    /// Waits for outstanding work to finish and then kills the pool.
    ///
    /// Does nothing if the pool is already dead.
    pub fn shutdown(&self) {
        if !self.shared.dead.load(Ordering::SeqCst) {
            self.join();
            self.kill();
        }
    }

    /// Marks the pool as dead and kills every worker that still has a join
    /// handle. Calling it again is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal mutexes is poisoned.
    pub fn kill(&self) {
        // `swap` makes "check and mark dead" a single atomic step, so exactly
        // one caller goes on to kill the workers.
        if self.shared.dead.swap(true, Ordering::SeqCst) {
            return;
        }
        for monitor in self.shared.monitors.lock().unwrap().iter() {
            if let Some(mut join_handle) = monitor.join_handle.lock().unwrap().take() {
                join_handle.kill().ok();
            }
        }
    }

    /// Panics with "The process pool is dead" if the pool has been killed.
    fn assert_alive(&self) {
        if self.shared.dead.load(Ordering::SeqCst) {
            panic!("The process pool is dead");
        }
    }
}
/// Configuration used to build a `Pool`.
#[derive(Debug)]
pub struct PoolBuilder {
/// Number of workers to start.
size: usize,
/// When set, workers are started with stdin set to `Stdio::null()`.
disable_stdin: bool,
/// When set, workers are started with stdout set to `Stdio::null()`.
disable_stdout: bool,
/// When set, workers are started with stderr set to `Stdio::null()`.
disable_stderr: bool,
/// Common process settings, cloned into the `Builder` of every worker.
common: ProcCommon,
}
impl PoolBuilder {
    /// Creates a builder for `size` workers with no stream disabled and
    /// default common settings.
    fn new(size: usize) -> PoolBuilder {
        PoolBuilder {
            common: ProcCommon::default(),
            disable_stderr: false,
            disable_stdout: false,
            disable_stdin: false,
            size,
        }
    }

    // Methods generated by the `define_common_methods!` macro, which is
    // defined outside this file.
    define_common_methods!();

    /// Requests that workers are started with stdin set to null.
    pub fn disable_stdin(&mut self) -> &mut Self {
        self.disable_stdin = true;
        self
    }

    /// Requests that workers are started with stdout set to null.
    pub fn disable_stdout(&mut self) -> &mut Self {
        self.disable_stdout = true;
        self
    }

    /// Requests that workers are started with stderr set to null.
    pub fn disable_stderr(&mut self) -> &mut Self {
        self.disable_stderr = true;
        self
    }

    /// Builds the pool and starts all of its workers.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `spawn_worker`.
    ///
    /// # Panics
    ///
    /// Panics if the monitor list mutex is poisoned.
    pub fn build(&mut self) -> Result<Pool, SpawnError> {
        let (job_tx, job_rx) = mpsc::channel();
        let worker_count = self.size;

        let shared = Arc::new(PoolShared {
            call_receiver: Mutex::new(job_rx),
            empty_trigger: Mutex::new(()),
            empty_condvar: Condvar::new(),
            join_generation: AtomicUsize::new(0),
            monitors: Mutex::new(Vec::with_capacity(worker_count)),
            queued_count: AtomicUsize::new(0),
            active_count: AtomicUsize::new(0),
            dead: AtomicBool::new(false),
        });

        {
            // The monitor list stays locked while the workers are started.
            let mut monitors = shared.monitors.lock().unwrap();
            for _ in 0..worker_count {
                let monitor = spawn_worker(Arc::clone(&shared), self)?;
                monitors.push(monitor);
            }
        }

        Ok(Pool {
            sender: Mutex::new(job_tx),
            shared,
        })
    }
}
impl Drop for Pool {
fn drop(&mut self) {
self.kill();
}
}
/// State shared between a `Pool` and the monitor threads of its workers.
struct PoolShared {
/// Receiving half of the job queue; a monitor thread locks it to take the next job.
#[allow(clippy::type_complexity)]
call_receiver: Mutex<
mpsc::Receiver<(
MarshalledCall,
Arc<PooledHandleState>,
WaitFunc,
NotifyErrorFunc,
)>,
>,
/// Mutex paired with `empty_condvar`; it guards no data of its own.
empty_trigger: Mutex<()>,
/// Notified when the pool has no queued or active jobs left.
empty_condvar: Condvar,
/// Counter advanced by `Pool::join` once a wait has completed.
join_generation: AtomicUsize,
/// One entry per worker started for this pool.
monitors: Mutex<Vec<WorkerMonitor>>,
/// Jobs submitted but not yet taken by a monitor thread.
queued_count: AtomicUsize,
/// Jobs currently being handled by a monitor thread.
active_count: AtomicUsize,
/// Set once the pool has been killed.
dead: AtomicBool,
}
impl PoolShared {
    /// Returns `true` while at least one job is queued or active.
    fn has_work(&self) -> bool {
        if self.queued_count.load(Ordering::SeqCst) != 0 {
            return true;
        }
        self.active_count.load(Ordering::SeqCst) != 0
    }

    /// Wakes every thread blocked in `Pool::join` if the pool has run out of
    /// work; does nothing otherwise.
    fn no_work_notify_all(&self) {
        if self.has_work() {
            return;
        }

        // Take and immediately release the mutex that joiners hold between
        // their `has_work` check and their `wait`, then notify.
        let guard = self
            .empty_trigger
            .lock()
            .expect("Unable to notify all joining threads");
        drop(guard);

        self.empty_condvar.notify_all();
    }
}
/// Per-worker bookkeeping kept by the pool.
struct WorkerMonitor {
/// Join handle of the worker's current process; `None` once it has been taken
/// (by `Pool::kill` or before a restart).
join_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}
/// Starts one pool worker: a monitor thread plus the worker process it feeds.
///
/// The monitor thread takes jobs from the shared queue, forwards each call to
/// the worker process over an IPC channel, waits for the outcome, and
/// restarts the process when a job could not be dispatched or completed.
///
/// # Errors
///
/// Returns an error if the monitor thread cannot be started.
///
/// # Panics
///
/// Panics if creating the IPC channel for the worker process fails or if one
/// of the internal mutexes is poisoned.
fn spawn_worker(
    shared: Arc<PoolShared>,
    builder: &PoolBuilder,
) -> Result<WorkerMonitor, SpawnError> {
    let join_handle = Arc::new(Mutex::new(None::<JoinHandle<()>>));
    let current_call_tx = Arc::new(Mutex::new(None::<ipc::IpcSender<MarshalledCall>>));

    // (Re)starts the worker process: creates a fresh call channel, spawns the
    // process with the builder's settings, and records its join handle and
    // call sender.
    let spawn = Arc::new(Mutex::new({
        let disable_stdin = builder.disable_stdin;
        let disable_stdout = builder.disable_stdout;
        let disable_stderr = builder.disable_stderr;
        let common = builder.common.clone();
        let join_handle = join_handle.clone();
        let current_call_tx = current_call_tx.clone();
        move || {
            let (call_tx, call_rx) = ipc::channel::<MarshalledCall>().unwrap();
            let mut builder = Builder::new();
            builder.common(common.clone());
            if disable_stdin {
                builder.stdin(process::Stdio::null());
            }
            if disable_stdout {
                builder.stdout(process::Stdio::null());
            }
            if disable_stderr {
                builder.stderr(process::Stdio::null());
            }
            // The function handed to `Builder::spawn` (presumably executed in
            // the child process) runs calls until its channel closes.
            *join_handle.lock().unwrap() = Some(builder.spawn(call_rx, |rx| {
                while let Ok(call) = rx.recv() {
                    call.call(false);
                }
            }));
            *current_call_tx.lock().unwrap() = Some(call_tx);
        }
    }));

    // Joins the current worker process (if its handle is still present),
    // reports the outcome through `f`, and starts a replacement unless the
    // pool is dead.
    let check_for_restart = {
        let spawn = spawn.clone();
        let join_handle = join_handle.clone();
        let shared = shared.clone();
        move |f: &mut NotifyErrorFunc| {
            if let Some(join_handle) = join_handle.lock().unwrap().take() {
                match join_handle.join() {
                    // The process ended without reporting an error of its
                    // own, yet the job was not completed.
                    Ok(()) => f(SpawnError::from(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "client process died",
                    ))),
                    Err(err) => f(err),
                }
            }
            if !shared.dead.load(Ordering::SeqCst) {
                (*spawn.lock().unwrap())();
            }
        }
    };

    {
        let join_handle = join_handle.clone();
        thread::Builder::new()
            .name("procspawn-monitor".into())
            .spawn(move || {
                while !shared.dead.load(Ordering::SeqCst) {
                    let (call, state, wait_func, mut err_func) = {
                        let lock = shared
                            .call_receiver
                            .lock()
                            .expect("Monitor thread unable to lock call receiver");
                        match lock.recv() {
                            Ok(rv) => rv,
                            // All senders are gone: no more jobs will arrive.
                            Err(_) => break,
                        }
                    };

                    // Count the job as active before removing it from the
                    // queued count, so `has_work` does not see both counters
                    // at zero while this job is in flight.
                    shared.active_count.fetch_add(1, Ordering::SeqCst);
                    shared.queued_count.fetch_sub(1, Ordering::SeqCst);

                    if state.cancelled.load(Ordering::SeqCst) {
                        // Cancelled before it started: report it, don't run it.
                        err_func(SpawnError::new_cancelled());
                    } else {
                        // Expose the worker's process state to the job's
                        // handle so the handle can kill the process.
                        if let Some(ref mut handle) = *join_handle.lock().unwrap() {
                            *state.process_handle_state.lock().unwrap() =
                                handle.process_handle_state();
                        }

                        let sent = {
                            let mut call_tx = current_call_tx.lock().unwrap();
                            match *call_tx {
                                Some(ref mut call_tx) => {
                                    with_ipc_mode(|| call_tx.send(call)).is_ok()
                                }
                                None => false,
                            }
                        };
                        // Only wait for a result if the call was dispatched.
                        let restart = !sent || !wait_func();

                        *state.process_handle_state.lock().unwrap() = None;
                        if restart {
                            check_for_restart(&mut err_func);
                        }
                    }

                    shared.active_count.fetch_sub(1, Ordering::SeqCst);
                    shared.no_work_notify_all();
                }
            })?;
    }

    // Start the initial worker process.
    (*spawn.lock().unwrap())();
    Ok(WorkerMonitor { join_handle })
}