use std::panic;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{mpsc, Arc};
use crate::blockifier::config::WorkerPoolConfig;
use crate::concurrency::utils::AbortIfPanic;
use crate::concurrency::worker_logic::WorkerExecutor;
use crate::concurrency::TxIndex;
use crate::state::state_api::StateReader;
/// A pool of worker threads that run [`WorkerExecutor`]s.
///
/// Each worker thread has its own channel. Dispatching an executor sends a clone of the same
/// `Arc` to every worker, so all workers run the same executor.
#[derive(Debug)]
pub struct WorkerPool<S: StateReader> {
    /// One sender per worker thread: `Some(executor)` dispatches work, `None` asks the worker
    /// to stop.
    senders: Vec<mpsc::Sender<Option<Arc<WorkerExecutor<S>>>>>,
    /// Join handles of the spawned worker threads, in spawn order.
    handlers: Vec<std::thread::JoinHandle<()>>,
    /// Shared with every worker; set by a worker when an executor panics on its thread.
    a_thread_panicked: Arc<AtomicBool>,
}
impl<S: StateReader + Send + 'static> WorkerPool<S> {
    /// Spawns `config.n_workers` worker threads and returns a pool that dispatches to them.
    ///
    /// Each thread gets a stack of `config.stack_size` bytes and owns the receiving end of
    /// its own channel. The pool keeps the matching sender and the thread's join handle.
    ///
    /// # Panics
    ///
    /// Panics if a worker thread cannot be spawned.
    pub fn start(config: &WorkerPoolConfig) -> Self {
        let a_thread_panicked = Arc::new(AtomicBool::new(false));
        let (senders, receivers): (Vec<_>, Vec<_>) = (0..config.n_workers)
            .map(|_| mpsc::channel::<Option<Arc<WorkerExecutor<S>>>>())
            .unzip();
        let stack_size = config.stack_size;
        let handlers = receivers
            .into_iter()
            .enumerate()
            .map(|(thread_id, receiver)| {
                let worker_thread = WorkerThread {
                    a_thread_panicked: Arc::clone(&a_thread_panicked),
                    receiver,
                    thread_id,
                };
                std::thread::Builder::new()
                    .stack_size(stack_size)
                    .spawn(move || worker_thread.run_thread())
                    .expect("Failed to spawn thread.")
            })
            .collect();
        WorkerPool { senders, handlers, a_thread_panicked }
    }

    /// Sends `worker_executor` to every worker thread and returns without waiting.
    ///
    /// # Panics
    ///
    /// Panics if a worker thread has already panicked, or if a worker can no longer receive
    /// messages because its thread has exited.
    pub fn run(&self, worker_executor: Arc<WorkerExecutor<S>>) {
        // Refuse to dispatch on a poisoned pool. Surviving workers would reject the executor
        // anyway, and the worker that panicked no longer has a receiver.
        self.check_panic();
        for sender in &self.senders {
            sender
                .send(Some(Arc::clone(&worker_executor)))
                .expect("Failed to send worker executor.");
        }
    }

    /// Dispatches `worker_executor` to all workers and blocks in
    /// `scheduler.wait_for_completion(target_n_txs)`. It then halts the scheduler and checks
    /// whether any worker panicked.
    ///
    /// NOTE(review): presumably `wait_for_completion` also returns early when a panicking
    /// worker halts the scheduler — confirm in the scheduler.
    ///
    /// # Panics
    ///
    /// Panics if any worker thread panicked.
    pub fn run_and_wait(&self, worker_executor: Arc<WorkerExecutor<S>>, target_n_txs: TxIndex) {
        self.run(Arc::clone(&worker_executor));
        worker_executor.scheduler.wait_for_completion(target_n_txs);
        worker_executor.scheduler.halt();
        self.check_panic();
    }

    /// Panics if any worker thread has panicked.
    pub fn check_panic(&self) {
        if self.a_thread_panicked.load(Ordering::Acquire) {
            panic!("One of the threads panicked.");
        }
    }

    /// Stops all worker threads and waits for them to exit.
    ///
    /// Every worker is signalled before any is joined. A worker that already died therefore
    /// does not prevent the others from shutting down cleanly.
    ///
    /// # Panics
    ///
    /// Panics after all threads have been joined if any worker thread panicked.
    pub fn join(self) {
        for sender in &self.senders {
            // `send` fails only if the worker's receiver is gone. Workers stop only on `None`,
            // which is sent nowhere else, so that means the thread already died from a panic.
            // The join below reports it.
            let _ = sender.send(None);
        }
        let results: Vec<_> =
            self.handlers.into_iter().map(std::thread::JoinHandle::join).collect();
        for result in results {
            result.expect("Failed to join thread.");
        }
    }
}
/// State owned by a single pool worker thread.
struct WorkerThread<S: StateReader> {
    /// Index of this worker within the pool, reported in log messages.
    thread_id: usize,
    /// Incoming work: `Some(executor)` is run, `None` ends the worker loop.
    receiver: mpsc::Receiver<Option<Arc<WorkerExecutor<S>>>>,
    /// Panic flag shared with the pool and every other worker.
    a_thread_panicked: Arc<AtomicBool>,
}
impl<S: StateReader> WorkerThread<S> {
    /// Worker-thread main loop.
    ///
    /// Blocks on the channel and runs each received executor in turn, logging before and after
    /// each run. Returns when a `None` stop message is received.
    ///
    /// # Panics
    ///
    /// Panics if the channel is disconnected without a stop message. Also propagates any panic
    /// raised by `_run_executor`.
    fn run_thread(&self) {
        // Number of executors this thread has run so far; used only in the log messages.
        let mut i = 0;
        while let Some(worker_executor) =
            self.receiver.recv().expect("Failed to receive worker executor.")
        {
            let block_number = worker_executor.block_context.block_info.block_number;
            log::debug!(
                "Worker pool (thread {}) starting worker #{} (block number {block_number})",
                self.thread_id,
                i,
            );
            self._run_executor(&worker_executor);
            log::debug!(
                "Worker pool (thread {}) worker done #{} (block number {block_number})",
                self.thread_id,
                i,
            );
            i += 1;
        }
    }

    /// Runs `worker_executor` on the current thread.
    ///
    /// If another worker has already panicked, this thread halts the executor's scheduler and
    /// panics without running it. If the executor itself panics, this thread:
    /// 1. sets the shared `a_thread_panicked` flag,
    /// 2. halts the scheduler,
    /// 3. re-raises the panic.
    fn _run_executor(&self, worker_executor: &WorkerExecutor<S>) {
        if self.a_thread_panicked.load(Ordering::Acquire) {
            // Mirror the panic path below by halting the scheduler before bailing out. A caller
            // waiting on this executor is then not left blocked by a worker that refuses to run it.
            worker_executor.scheduler.halt();
            panic!("Another thread panicked. Aborting.");
        }
        // NOTE(review): `AbortIfPanic` presumably aborts the process if dropped without
        // `release()`, e.g. if anything below panics outside `catch_unwind`. Confirm in
        // `concurrency::utils`.
        let abort_guard = AbortIfPanic;
        let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            worker_executor.run();
        }));
        if let Err(err) = res {
            // The flag is published before halting. Presumably this is so the pool thread,
            // once released from waiting, observes it in `check_panic` (that check uses an
            // Acquire load).
            self.a_thread_panicked.store(true, Ordering::Release);
            worker_executor.scheduler.halt();
            abort_guard.release();
            panic::resume_unwind(err);
        }
        abort_guard.release();
    }
}