use core::num;
use std::{
cell::{Cell, UnsafeCell},
sync::{
atomic::Ordering,
atomic::{AtomicBool, AtomicUsize},
Arc,
}, ptr::null,
};
use parking_lot::{Condvar, Mutex};
use super::semaphore::Sem;
pub trait WorkerTask: Send + Sync {
fn name(&self) -> &str;
fn work(&self, worker_id: usize);
}
impl WorkerTask for () {
fn name(&self) -> &str {
"no-op"
}
fn work(&self, worker_id: usize) {
let _ = worker_id;
}
}
pub struct WorkerTaskDispatcher {
task: UnsafeCell<*const dyn WorkerTask>,
started: AtomicUsize,
not_finished: AtomicUsize,
start_semaphore: Sem,
end_semaphore: Sem,
shut_down: AtomicBool,
did_shutdown: AtomicUsize,
mutex: Mutex<bool>,
cv: Condvar,
}
impl WorkerTaskDispatcher {
fn new() -> Self {
Self {
task: UnsafeCell::new(null::<()>()),
started: AtomicUsize::new(0),
not_finished: AtomicUsize::new(0),
start_semaphore: Sem::new(0).unwrap(),
end_semaphore: Sem::new(0).unwrap(),
shut_down: AtomicBool::new(false),
did_shutdown: AtomicUsize::new(0),
mutex: Mutex::new(false),
cv: Condvar::new(),
}
}
fn coordinator_distribute_task<'a>(&self, task: &'a dyn WorkerTask, num_workers: usize) {
unsafe {
self.task.get().write(std::mem::transmute(task));
}
self.not_finished.store(num_workers, Ordering::Relaxed);
self.start_semaphore.signal(num_workers);
self.end_semaphore.wait();
unsafe {
let p = &mut *self.task.get();
*p = null::<()>();
}
}
fn worker_run_task(&self) {
self.start_semaphore.wait();
let worker_id = self.started.fetch_add(1, Ordering::AcqRel);
WorkerThread::set_worker_id(worker_id);
unsafe {
let task_ptr = self.task.get();
if task_ptr.read().is_null() {
return;
}
let task_ref = &**task_ptr;
task_ref.work(worker_id);
}
let not_finished = self.not_finished.fetch_sub(1, Ordering::AcqRel);
if not_finished == 1 {
self.end_semaphore.signal(1);
}
}
}
pub struct WorkerThreads {
name: &'static str,
workers: Vec<Arc<WorkerThread>>,
max_workers: usize,
created_workers: usize,
active_workers: usize,
dispatcher: Arc<WorkerTaskDispatcher>,
}
impl WorkerThreads {
pub fn active_workers(&self) -> usize {
self.active_workers
}
pub fn new(name: &'static str, max_workers: usize) -> Self {
Self {
name,
workers: Vec::with_capacity(max_workers),
created_workers: 0,
active_workers: 0,
dispatcher: Arc::new(WorkerTaskDispatcher::new()),
max_workers
}
}
pub fn initialize_workers(&mut self) {
self.set_active_workers(self.max_workers);
}
fn create_worker(&mut self) {
let worker = Arc::new(WorkerThread {
dispatcher: self.dispatcher.clone(),
});
self.workers.push(worker.clone());
let _ = std::thread::spawn(move || {
worker.run();
});
let mut started = self.dispatcher.mutex.lock();
if !*started {
self.dispatcher.cv.wait(&mut started);
}
*started = false;
}
pub fn set_active_workers(&mut self, num_workers: usize) -> usize {
while self.created_workers < num_workers {
self.create_worker();
self.created_workers += 1;
}
self.active_workers = self.created_workers.min(num_workers);
self.active_workers
}
pub fn run_task<'a>(&self, task: &'a dyn WorkerTask) {
self.dispatcher.coordinator_distribute_task(task, self.active_workers);
}
pub fn run_task_wworkers<'a>(&mut self, task: &'a dyn WorkerTask, num_workers: usize) {
let x = self.active_workers;
self.set_active_workers(num_workers);
self.dispatcher.coordinator_distribute_task(task, num_workers);
self.set_active_workers(x);
}
}
thread_local! {
static WORKER_ID: Cell<usize> = Cell::new(usize::MAX);
}
pub struct WorkerThread {
dispatcher: Arc<WorkerTaskDispatcher>,
}
unsafe impl Send for WorkerThread {}
unsafe impl Sync for WorkerThread {}
impl WorkerThread {
pub fn worker_id() -> usize {
WORKER_ID.with(|id| id.get())
}
fn set_worker_id(id_: usize) {
WORKER_ID.with(|id| id.set(id_));
}
fn run(&self) {
{
let mut started = self.dispatcher.mutex.lock();
*started = true;
self.dispatcher.cv.notify_one();
}
while !self.dispatcher.shut_down.load(Ordering::Relaxed) {
self.dispatcher.worker_run_task();
}
self.dispatcher.did_shutdown.fetch_sub(1, Ordering::AcqRel);
}
}