use crate::error::PayloadError;
use crate::project::error::ProjectError;
use crossbeam::channel::{bounded, unbounded, Receiver, SendError, Sender, TryRecvError};
use crossbeam::deque::{Injector, Steal, Stealer, Worker};
use std::any::Any;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread::JoinHandle;
use std::{io, panic, thread};
use uuid::Uuid;
/// A unit of work for [`WorkerExecutor`], plus callbacks run around it.
///
/// A worker thread calls `on_start`, then `work`, then `on_complete`, in that order.
pub struct WorkToken {
/// Called on the worker thread immediately before `work`.
pub on_start: Box<dyn Fn() + Send + 'static>,
/// Called on the worker thread after `work` returns; not reached if `work` panics.
pub on_complete: Box<dyn Fn() + Send + 'static>,
/// The job itself; consumed when it runs.
pub work: Box<dyn FnOnce() + Send + 'static>,
}
impl WorkToken {
fn new(
on_start: Box<dyn Fn() + Send + 'static>,
on_complete: Box<dyn Fn() + Send + 'static>,
work: Box<dyn FnOnce() + Send + 'static>,
) -> Self {
Self {
on_start,
on_complete,
work,
}
}
}
/// Types that can be converted into a [`WorkToken`].
///
/// Implementors supply `work`; the `on_start` / `on_complete` callbacks default to closures
/// that do nothing.
pub trait ToWorkToken: Send + 'static {
    /// Callback to run right before the work; a no-op unless overridden.
    fn on_start(&self) -> Box<dyn Fn() + Send + Sync> {
        let noop = || ();
        Box::new(noop)
    }

    /// Callback to run right after the work; a no-op unless overridden.
    fn on_complete(&self) -> Box<dyn Fn() + Send + Sync> {
        let noop = || ();
        Box::new(noop)
    }

    /// Performs the work, consuming `self`.
    fn work(self);
}
/// Any [`ToWorkToken`] converts into a [`WorkToken`]. Its callbacks are captured first,
/// because `work` consumes the value.
impl<T: ToWorkToken> From<T> for WorkToken {
    fn from(tok: T) -> Self {
        let (start_cb, complete_cb) = (tok.on_start(), tok.on_complete());
        let builder = WorkTokenBuilder::new(move || tok.work());
        builder.on_start(start_cb).on_complete(complete_cb).build()
    }
}
/// Every sendable `FnOnce()` closure is a [`ToWorkToken`]; its work is simply calling it.
impl<F: FnOnce() + Send + 'static> ToWorkToken for F {
    fn work(self) {
        self()
    }
}
/// No-op callback that [`WorkTokenBuilder::new`] installs for `on_start` and `on_complete`.
fn empty() {}

/// Type-state builder for [`WorkToken`].
///
/// Start with [`WorkTokenBuilder::new`]. Optionally replace the no-op callbacks with
/// [`on_start`](WorkTokenBuilder::on_start) / [`on_complete`](WorkTokenBuilder::on_complete),
/// then call [`build`](WorkTokenBuilder::build). Each setter changes a type parameter, so the
/// callbacks stay unboxed until `build`.
///
/// Trait bounds live on the `impl` blocks rather than on the struct, so merely naming the type
/// does not force callers to repeat them.
#[must_use = "a WorkTokenBuilder does nothing until `build` is called"]
pub struct WorkTokenBuilder<W, S, C> {
    on_start: S,
    on_complete: C,
    work: W,
}

impl<W, S, C> WorkTokenBuilder<W, S, C>
where
    W: FnOnce() + Send + 'static,
    S: Fn() + Send + 'static,
    C: Fn() + Send + 'static,
{
    /// Boxes the work and both callbacks into a [`WorkToken`].
    pub fn build(self) -> WorkToken {
        WorkToken::new(
            Box::new(self.on_start),
            Box::new(self.on_complete),
            Box::new(self.work),
        )
    }
}

impl<W> WorkTokenBuilder<W, fn(), fn()>
where
    W: FnOnce() + Send + 'static,
{
    /// Starts a builder for `work` with no-op `on_start` / `on_complete` callbacks.
    pub fn new(work: W) -> Self {
        Self {
            on_start: empty,
            on_complete: empty,
            work,
        }
    }
}

impl<W, S, C> WorkTokenBuilder<W, S, C> {
    /// Replaces the callback run just before the work.
    pub fn on_start<S2: Fn() + Send + 'static>(self, on_start: S2) -> WorkTokenBuilder<W, S2, C> {
        WorkTokenBuilder {
            on_start,
            on_complete: self.on_complete,
            work: self.work,
        }
    }

    /// Replaces the callback run just after the work.
    pub fn on_complete<C2: Fn() + Send + 'static>(
        self,
        on_complete: C2,
    ) -> WorkTokenBuilder<W, S, C2> {
        WorkTokenBuilder {
            on_start: self.on_start,
            on_complete,
            work: self.work,
        }
    }
}
/// Requests the executor sends to its coordinator thread.
enum WorkerQueueRequest {
/// Ask for the most recent status the coordinator holds for every worker.
GetStatus,
}
/// Replies from the coordinator; exactly one is sent per `WorkerQueueRequest` it receives.
enum WorkerQueueResponse {
/// Snapshot of the coordinator's status map, keyed by worker id.
Status(HashMap<Uuid, WorkerStatus>),
}
/// Control messages the coordinator sends to worker threads.
#[derive(Debug, Eq, PartialEq)]
enum WorkerMessage {
/// Tells a worker to leave its loop; workers check for it before stealing more work.
Stop,
}
/// Id attached to each submitted job; `WorkerExecutor::submit` draws it with `rand::random`.
type WorkTokenId = u64;
/// A fixed-size pool of worker threads that runs submitted [`WorkToken`]s.
///
/// Submitted work is pushed onto a shared [`Injector`]. A coordinator thread moves it into a
/// local deque that the worker threads steal from. Dropping the executor signals the
/// coordinator to stop and joins it.
pub struct WorkerExecutor {
    /// Number of worker threads spawned for this executor.
    max_jobs: usize,
    /// Global queue that [`WorkerExecutor::submit`] pushes onto.
    injector: Arc<Injector<WorkerTuple>>,
    /// Link to the coordinator thread; `None` once it has been joined.
    connection: Option<Connection>,
    /// Number of submitted tokens that have not finished yet (queued or running).
    /// Maintained by [`PendingGuard`].
    pending: Arc<AtomicUsize>,
}

/// Channels and join handle used to talk to, and shut down, the coordinator thread.
struct Connection {
    /// Sending on (or dropping) this asks the coordinator to stop.
    join_send: Sender<()>,
    /// Join handle of the coordinator thread.
    inner_handle: JoinHandle<()>,
    /// Carries requests to the coordinator.
    request_sender: Sender<WorkerQueueRequest>,
    /// Carries the coordinator's replies.
    response_receiver: Receiver<WorkerQueueResponse>,
}

impl Connection {
    /// Sends `request` to the coordinator and blocks until it answers.
    ///
    /// # Panics
    ///
    /// Panics if the coordinator thread has exited (for example because it panicked), since
    /// no answer can then arrive.
    fn handle_request(&self, request: WorkerQueueRequest) -> WorkerQueueResponse {
        self.request_sender
            .send(request)
            .expect("coordinator thread stopped before receiving the request");
        self.response_receiver
            .recv()
            .expect("coordinator thread stopped before answering the request")
    }
}

/// Keeps one submitted token counted in the executor's pending-work counter.
///
/// Creating the guard increments the counter; dropping it decrements the counter. Exactly one
/// guard travels with every submitted token.
struct PendingGuard(Arc<AtomicUsize>);

impl PendingGuard {
    /// Registers one more pending token on `counter`.
    fn new(counter: &Arc<AtomicUsize>) -> Self {
        counter.fetch_add(1, Ordering::SeqCst);
        PendingGuard(Arc::clone(counter))
    }
}

impl Drop for PendingGuard {
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::SeqCst);
    }
}

impl Drop for WorkerExecutor {
    fn drop(&mut self) {
        // Best effort: `drop` has no way to return a coordinator panic, so the result is
        // deliberately discarded.
        let _ = self.join_inner();
    }
}

impl WorkerExecutor {
    /// Spawns `pool_size` worker threads plus the coordinator thread that feeds them.
    ///
    /// # Errors
    ///
    /// Returns an error if a worker thread cannot be spawned.
    pub fn new(pool_size: usize) -> io::Result<Self> {
        let mut out = Self {
            max_jobs: pool_size,
            injector: Arc::new(Injector::new()),
            connection: None,
            pending: Arc::new(AtomicUsize::new(0)),
        };
        out.start()?;
        Ok(out)
    }

    /// Starts the workers and the coordinator, and stores the connection to the coordinator.
    fn start(&mut self) -> io::Result<()> {
        self.connection = Some(Inner::start(&self.injector, self.max_jobs)?);
        Ok(())
    }

    /// Waits for the submitted work (see [`WorkerExecutor::finish_jobs`]), then stops and
    /// joins the coordinator.
    ///
    /// # Errors
    ///
    /// Fails if waiting for the jobs fails or if the coordinator thread panicked.
    pub fn join(mut self) -> Result<(), PayloadError<ProjectError>> {
        self.finish_jobs().map_err(PayloadError::new)?;
        self.join_inner().map_err(PayloadError::new)?;
        Ok(())
    }

    /// Stops the coordinator (if it is still connected) and joins its thread.
    ///
    /// After the first call the connection is gone, so later calls return `Ok(())`.
    fn join_inner(&mut self) -> thread::Result<()> {
        if let Some(connection) = self.connection.take() {
            // The coordinator may already have exited, in which case nobody is listening.
            let _ = connection.join_send.send(());
            connection.inner_handle.join()?;
        }
        Ok(())
    }

    /// Queues `token` for execution and returns a handle that can wait for it.
    ///
    /// The returned `io::Result` is currently always `Ok`.
    pub fn submit<I: Into<WorkToken>>(&self, token: I) -> io::Result<WorkHandle> {
        let mut work_token = token.into();
        // The pending-work guard rides inside `on_complete`. That closure, and the guard with
        // it, is dropped only after the worker has called it, has unwound from a panicking
        // job, or the token has been discarded unrun. So `finish_jobs` cannot return while
        // this token is still queued or running.
        let guard = PendingGuard::new(&self.pending);
        let on_complete = work_token.on_complete;
        work_token.on_complete = Box::new(move || {
            let _pending = &guard;
            on_complete()
        });
        let (handle, channel) = work_channel(self);
        let id = rand::random();
        self.injector.push(WorkerTuple(id, work_token, channel));
        Ok(handle)
    }

    /// Returns `true` if the coordinator has recorded a panic for any worker.
    pub fn any_panicked(&self) -> bool {
        let status = self
            .connection
            .as_ref()
            .map(|connection| connection.handle_request(WorkerQueueRequest::GetStatus));
        match status {
            Some(WorkerQueueResponse::Status(statuses)) => statuses
                .values()
                .any(|status| status == &WorkerStatus::Panic),
            None => false,
        }
    }

    /// Blocks until every token submitted so far has finished.
    ///
    /// It also returns once every worker has panicked, since queued tokens can then never
    /// run. Waiting is done by polling, and the thread yields between checks.
    ///
    /// # Panics
    ///
    /// Panics if the coordinator connection is gone, or if the coordinator thread has exited.
    pub fn finish_jobs(&mut self) -> io::Result<()> {
        // Only `join_inner` clears the connection, and it runs only from `join` (which
        // consumes `self`) and `drop`.
        let connection = self.connection.as_ref().expect("Shouldn't be possible");
        while self.pending.load(Ordering::SeqCst) > 0 {
            let all_workers_panicked =
                match connection.handle_request(WorkerQueueRequest::GetStatus) {
                    WorkerQueueResponse::Status(statuses) => statuses
                        .values()
                        .all(|status| status == &WorkerStatus::Panic),
                };
            if all_workers_panicked {
                break;
            }
            thread::yield_now();
        }
        Ok(())
    }

    /// Creates a [`WorkerQueue`] that submits to this executor and waits for its jobs when
    /// dropped.
    pub fn queue(&self) -> WorkerQueue {
        WorkerQueue::new(self)
    }
}
/// Coordinator state; `Inner::start` moves it onto its own thread, which runs `Inner::run`.
struct Inner {
/// Number of worker threads this coordinator was created with.
max_jobs: usize,
/// Global queue shared with the executor.
injector: Arc<Injector<WorkerTuple>>,
/// Local deque that queued work is batched into and that workers steal from.
worker: Worker<WorkerTuple>,
/// Sends control messages (e.g. `Stop`) to the workers.
message_sender: Sender<WorkerMessage>,
/// Receives status reports from the workers.
status_receiver: Receiver<WorkStatusUpdate>,
/// Stop signal from the executor.
stop_receiver: Receiver<()>,
/// Join handles of the worker threads.
handles: Vec<JoinHandle<()>>,
/// Latest status processed for each worker id.
id_to_status: HashMap<Uuid, WorkerStatus>,
/// Requests from the executor.
request_recv: Receiver<WorkerQueueRequest>,
/// Replies to the executor.
response_sndr: Sender<WorkerQueueResponse>,
}
/// Handle for waiting on one submitted job (see `WorkHandle::join`).
///
/// NOTE(review): the job's completion channel is `bounded(1)` and receives a single `()`.
/// When a handle is cloned (as `WorkerQueue::submit` does), only one clone's `join` can
/// observe success. The other clones get a disconnect error once the job's sender is dropped.
#[derive(Clone)]
pub struct WorkHandle<'exec> {
/// Receives the job's completion signal.
recv: Receiver<()>,
/// Executor the job was submitted to; it stays borrowed while the handle lives.
owner: &'exec WorkerExecutor,
}
/// Creates the one-shot completion channel for a job.
///
/// The [`WorkHandle`] half waits. The `Sender` half travels with the job, and the worker
/// signals it when the job finishes.
fn work_channel(exec: &WorkerExecutor) -> (WorkHandle, Sender<()>) {
    let (done_tx, done_rx) = bounded::<()>(1);
    let handle = WorkHandle {
        recv: done_rx,
        owner: exec,
    };
    (handle, done_tx)
}
impl WorkHandle<'_> {
    /// Blocks until the job signals completion.
    ///
    /// # Errors
    ///
    /// Returns the channel's `RecvError`, boxed as a panic payload, if the job's sender was
    /// dropped without a signal being received.
    pub fn join(self) -> thread::Result<()> {
        match self.recv.recv() {
            Ok(()) => Ok(()),
            Err(disconnected) => {
                let payload: Box<dyn Any + Send> = Box::new(disconnected);
                Err(payload)
            }
        }
    }
}
mod inner_impl {
    use super::*;

    impl Inner {
        /// Builds the coordinator state and spawns `pool_size` worker threads.
        ///
        /// Returns the coordinator together with the request sender and response receiver
        /// the executor uses to query it. `stop_recv` is polled by [`Inner::run`] as its stop
        /// signal.
        ///
        /// # Errors
        ///
        /// Returns an error if a worker thread cannot be spawned.
        fn new(
            injector: &Arc<Injector<WorkerTuple>>,
            pool_size: usize,
            stop_recv: Receiver<()>,
        ) -> io::Result<(
            Self,
            Sender<WorkerQueueRequest>,
            Receiver<WorkerQueueResponse>,
        )> {
            let (message_sender, message_receiver) = unbounded();
            let (status_sender, status_receiver) = unbounded();
            let (request_sender, request_receiver) = unbounded();
            let (response_sender, response_receiver) = unbounded();
            let mut output = Self {
                max_jobs: pool_size,
                injector: Arc::clone(injector),
                worker: Worker::new_fifo(),
                message_sender,
                status_receiver,
                stop_receiver: stop_recv,
                handles: Vec::with_capacity(pool_size),
                id_to_status: HashMap::new(),
                request_recv: request_receiver,
                response_sndr: response_sender,
            };
            for _ in 0..pool_size {
                let stealer = output.worker.stealer();
                let (id, handle) = AssembleWorker::new(
                    stealer,
                    message_receiver.clone(),
                    status_sender.clone(),
                )
                .start()?;
                output.id_to_status.insert(id, WorkerStatus::Unknown);
                output.handles.push(handle);
            }
            Ok((output, request_sender, response_receiver))
        }

        /// Spawns the workers and the coordinator thread for a pool of `pool_size` workers
        /// fed from `injector`, and returns the executor's side of the connection.
        pub fn start(
            injector: &Arc<Injector<WorkerTuple>>,
            pool_size: usize,
        ) -> io::Result<Connection> {
            let (stop_sender, stop_receiver) = unbounded();
            let (inner, request_sender, response_receiver) =
                Self::new(injector, pool_size, stop_receiver)?;
            let inner_handle = thread::spawn(move || inner.run());
            Ok(Connection {
                join_send: stop_sender,
                inner_handle,
                request_sender,
                response_receiver,
            })
        }

        /// Coordinator loop. Until told to stop (or the stop sender is dropped), it:
        /// - moves queued work from the injector into the local deque,
        /// - records worker status reports,
        /// - answers executor requests.
        ///
        /// On stop it tells every worker to stop and joins them.
        fn run(mut self) {
            loop {
                match self.stop_receiver.try_recv() {
                    Ok(()) => break,
                    Err(TryRecvError::Empty) => {}
                    Err(_) => break,
                }
                let _ = self.injector.steal_batch(&self.worker);
                self.update_worker_status();
                self.handle_requests();
                // Every step above is non-blocking; yield so this polling loop gives up its
                // time slice to other runnable threads (such as the workers).
                thread::yield_now();
            }
            for _ in &self.handles {
                // Sending fails only once every worker has dropped its receiver, i.e. there
                // is nobody left to stop.
                let _ = self.message_sender.send(WorkerMessage::Stop);
            }
            for handle in self.handles {
                // A panicked worker's join error is ignored here; its panic is surfaced
                // through `WorkerStatus::Panic` instead.
                let _ = handle.join();
            }
        }

        /// Applies every status report currently waiting in the channel.
        fn update_worker_status(&mut self) {
            while let Ok(status) = self.status_receiver.try_recv() {
                self.id_to_status.insert(status.worker_id, status.status);
            }
        }

        /// Answers every request currently waiting in the channel, one response each.
        fn handle_requests(&mut self) {
            while let Ok(req) = self.request_recv.try_recv() {
                let response = self.on_request(req);
                self.response_sndr
                    .send(response)
                    .expect("Inner still exists while Outer gone")
            }
        }

        /// Builds the response to a single request.
        fn on_request(&mut self, request: WorkerQueueRequest) -> WorkerQueueResponse {
            match request {
                WorkerQueueRequest::GetStatus => {
                    // Drain status reports first, so that reports already sent when the
                    // request arrives are reflected in the answer.
                    self.update_worker_status();
                    WorkerQueueResponse::Status(self.id_to_status.clone())
                }
            }
        }
    }
}
/// Status a worker last reported to the coordinator.
#[derive(Debug, Clone, Eq, PartialEq)]
enum WorkerStatus {
/// Worker spawned, but no report of its own has been processed yet.
Unknown,
/// Running the job with the given id.
TaskRunning(WorkTokenId),
/// Waiting for work.
Idle,
/// The worker thread panicked while unwinding; it takes no further work.
Panic,
}
/// A status report sent from a worker thread to the coordinator.
struct WorkStatusUpdate {
/// Id of the reporting worker.
worker_id: Uuid,
/// The worker's new status.
status: WorkerStatus,
}
/// State owned by one worker thread of the pool.
struct AssembleWorker {
/// Random id the worker reports its status under.
id: Uuid,
/// Steals jobs from the coordinator's local deque.
stealer: Stealer<WorkerTuple>,
/// Control messages (e.g. `Stop`) from the coordinator.
message_recv: Receiver<WorkerMessage>,
/// Channel for reporting status changes to the coordinator.
status_send: Sender<WorkStatusUpdate>,
}
impl Drop for AssembleWorker {
    /// If the worker is being dropped while its thread unwinds, reports
    /// `WorkerStatus::Panic` so the coordinator can surface the failure.
    fn drop(&mut self) {
        if thread::panicking() {
            // Best effort. Panicking again while already unwinding would abort the whole
            // process, so a missing status receiver is deliberately ignored.
            let _ = self.report_status(WorkerStatus::Panic);
        }
    }
}
impl AssembleWorker {
pub fn new(
stealer: Stealer<WorkerTuple>,
message_recv: Receiver<WorkerMessage>,
status_send: Sender<WorkStatusUpdate>,
) -> Self {
let id = Uuid::new_v4();
Self {
id,
stealer,
message_recv,
status_send,
}
}
fn start(mut self) -> io::Result<(Uuid, JoinHandle<()>)> {
let id = self.id;
self.report_status(WorkerStatus::Idle).unwrap();
let handle = thread::Builder::new()
.name(format!("Assemble Worker (id = {})", id))
.spawn(move || self.run())?;
Ok((id, handle))
}
fn run(&mut self) {
'outer: loop {
match self.message_recv.try_recv() {
Ok(msg) => match msg {
WorkerMessage::Stop => break 'outer,
},
Err(TryRecvError::Empty) => {}
Err(_) => break 'outer,
}
if let Steal::Success(tuple) = self.stealer.steal() {
let WorkerTuple(id, work, vc) = tuple;
self.report_status(WorkerStatus::TaskRunning(id)).unwrap();
(work.on_start)();
(work.work)();
(work.on_complete)();
self.report_status(WorkerStatus::Idle).unwrap();
match vc.send(()) {
Ok(()) => {}
Err(_e) => {
}
}
}
}
}
fn report_status(&mut self, status: WorkerStatus) -> Result<(), SendError<WorkStatusUpdate>> {
self.status_send.send(WorkStatusUpdate {
worker_id: self.id,
status,
})
}
}
/// A queued job: its id, the token to run, and the sender that signals its `WorkHandle`.
struct WorkerTuple(WorkTokenId, WorkToken, Sender<()>);
/// Groups jobs submitted to a [`WorkerExecutor`] so they can be waited on together.
///
/// Dropping the queue blocks until every job it still tracks has signalled or dropped its
/// completion channel.
pub struct WorkerQueue<'exec> {
/// Executor that receives the submitted work.
executor: &'exec WorkerExecutor,
/// Handles of submitted jobs not yet waited on.
handles: Vec<WorkHandle<'exec>>,
}
impl<'exec> Drop for WorkerQueue<'exec> {
    /// Waits for every job this queue still tracks, in submission order. Failures are
    /// ignored because `drop` cannot report them.
    fn drop(&mut self) {
        self.handles.drain(..).for_each(|outstanding| {
            let _ = outstanding.join();
        });
    }
}
impl<'exec> WorkerQueue<'exec> {
    /// Creates an empty queue that submits to `executor`.
    pub fn new(executor: &'exec WorkerExecutor) -> Self {
        Self {
            executor,
            handles: vec![],
        }
    }

    /// Submits `work` to the executor and keeps a clone of its handle so the queue can wait
    /// for it.
    ///
    /// The returned handle borrows only the executor, not this queue. Several handles can
    /// therefore be held while more work is submitted.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`WorkerExecutor::submit`].
    pub fn submit<W: Into<WorkToken>>(&mut self, work: W) -> io::Result<WorkHandle<'exec>> {
        let handle = self.executor.submit(work)?;
        self.handles.push(handle.clone());
        Ok(handle)
    }

    /// Waits for every job submitted through this queue, in submission order.
    ///
    /// # Errors
    ///
    /// Returns the first join error. Handles after the failing one are discarded by the
    /// drain without being waited on, so those jobs may still be queued or running.
    ///
    /// NOTE(review): if a caller already joined a handle returned by `submit`, the queue's
    /// clone of it can report an error even though the job succeeded. The completion channel
    /// carries only one signal.
    pub fn join(mut self) -> thread::Result<()> {
        for handle in self.handles.drain(..) {
            handle.join()?;
        }
        Ok(())
    }

    /// Converts this queue into a [`TypedWorkerQueue`] that accepts only `W`, keeping the
    /// handles submitted so far.
    pub fn typed<W: Into<WorkToken>>(self) -> TypedWorkerQueue<'exec, W> {
        TypedWorkerQueue {
            _data: PhantomData,
            queue: self,
        }
    }
}
/// A [`WorkerQueue`] restricted to a single kind of work item `W`.
pub struct TypedWorkerQueue<'exec, W: Into<WorkToken>> {
/// Records the accepted work type without storing any `W`.
_data: PhantomData<W>,
/// Underlying untyped queue that tracks the jobs.
queue: WorkerQueue<'exec>,
}
impl<'exec, W: Into<WorkToken>> TypedWorkerQueue<'exec, W> {
    /// Creates an empty typed queue that submits to `executor`.
    pub fn new(executor: &'exec WorkerExecutor) -> Self {
        Self {
            _data: PhantomData,
            queue: executor.queue(),
        }
    }

    /// Submits `work` and records its handle in the underlying queue.
    ///
    /// The returned handle borrows only the executor, not this queue, so several handles can
    /// be held at once.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`WorkerExecutor::submit`].
    pub fn submit(&mut self, work: W) -> io::Result<WorkHandle<'exec>> {
        // Go through the executor directly, so the handle keeps the `'exec` lifetime instead
        // of borrowing `self`.
        let handle = self.queue.executor.submit(work)?;
        self.queue.handles.push(handle.clone());
        Ok(handle)
    }

    /// Waits for every job submitted through this queue, stopping at the first failure.
    pub fn join(self) -> thread::Result<()> {
        self.queue.join()
    }
}
#[cfg(test)]
mod tests {
    use crate::work_queue::WorkerExecutor;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    const WORK_SIZE: usize = 6;

    #[test]
    #[ignore]
    fn parallelism_works() {
        let mut worker_queue = WorkerExecutor::new(WORK_SIZE).unwrap();
        let add_all = Arc::new(AtomicUsize::new(0));
        let mut current_worker = 0;
        for _ in 0..(WORK_SIZE * 2) {
            let add_all = add_all.clone();
            let this_worker = current_worker;
            current_worker += 1;
            worker_queue
                .submit(move || {
                    debug!("running worker thread {}", this_worker);
                    add_all.fetch_add(1, Ordering::SeqCst);
                })
                .unwrap();
        }
        worker_queue.finish_jobs().unwrap();
        assert_eq!(add_all.load(Ordering::SeqCst), WORK_SIZE * 2);
        for _ in 0..(WORK_SIZE * 2) {
            let add_all = add_all.clone();
            let this_worker = current_worker;
            current_worker += 1;
            worker_queue
                .submit(move || {
                    debug!("running worker thread {}", this_worker);
                    add_all.fetch_add(1, Ordering::SeqCst);
                })
                .unwrap();
        }
        worker_queue.join().unwrap();
        assert_eq!(add_all.load(Ordering::SeqCst), WORK_SIZE * 4);
    }

    #[test]
    fn worker_queues_provide_protection() {
        let exec = WorkerExecutor::new(WORK_SIZE).unwrap();
        let accum = Arc::new(AtomicUsize::new(0));
        {
            let mut queue = exec.queue();
            for _i in 0..64 {
                let accum = accum.clone();
                queue
                    .submit(move || {
                        accum.fetch_add(1, Ordering::Relaxed);
                    })
                    .unwrap();
            }
        }
        assert_eq!(accum.load(Ordering::Acquire), 64);
    }

    /// Runs `4 * pool_size` sleeping jobs and checks that at most `pool_size` of them were
    /// ever running at the same time.
    fn test_executor_pool_size_ensured(pool_size: usize) {
        let workers_running = Arc::new(AtomicUsize::new(0));
        let max_workers_running = Arc::new(AtomicUsize::new(0));
        let executor = WorkerExecutor::new(pool_size).unwrap();
        {
            let mut queue = executor.queue();
            for _ in 0..4 * pool_size {
                let workers_running = workers_running.clone();
                let max_workers_running = max_workers_running.clone();
                queue
                    .submit(move || {
                        // Record the concurrency level at the moment this job starts.
                        let running_now = workers_running.fetch_add(1, Ordering::SeqCst) + 1;
                        max_workers_running.fetch_max(running_now, Ordering::SeqCst);
                        thread::sleep(Duration::from_millis(100));
                        workers_running.fetch_sub(1, Ordering::SeqCst);
                    })
                    .expect("failed to submit job");
            }
            queue.join().expect("worker task failed :(");
        }
        let max_workers_running = max_workers_running.load(Ordering::Acquire);
        println!("max running workers: {}", max_workers_running);
        assert!(max_workers_running <= pool_size);
    }

    #[test]
    fn only_correct_number_of_workers_run() {
        test_executor_pool_size_ensured(1);
        test_executor_pool_size_ensured(2);
        test_executor_pool_size_ensured(4);
        test_executor_pool_size_ensured(8);
    }

    #[test]
    #[ignore]
    fn can_stop_after_panic() {
        let executor = WorkerExecutor::new(1).unwrap();
        let job = executor.submit(|| panic!("WOOH I PANICKED")).unwrap();
        job.join()
            .expect_err("Should expect an error because a panic occurred");
        println!("any panicked = {}", executor.any_panicked());
        assert!(executor.any_panicked());
    }
}