use super::{Priority, WorkerConfig, WorkerHandle};
use std::collections::BinaryHeap;
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle};
use crate::utils::lock as lock_util;
/// State shared between the pool handle and every worker thread; always
/// accessed while holding the pool's `Mutex` (paired with a `Condvar` used
/// to wake idle workers).
struct SharedState {
// Pending tasks, popped in priority order (FIFO within equal priority).
queue: TaskQueue,
// Set by `WorkerPool::shutdown`; workers exit once this is true AND the
// queue has been drained.
shutdown: bool,
}
/// Fixed-size pool of worker threads that execute submitted closures in
/// priority order.
pub struct WorkerPool {
// One handle per spawned worker thread; joined on drop.
workers: Vec<Worker>,
// (queue + shutdown flag, wakeup condvar) shared with every worker.
state: Arc<(Mutex<SharedState>, Condvar)>,
// Retained so `thread_count` can report the configured pool size.
config: WorkerConfig,
}
impl WorkerPool {
    /// Creates a pool with `threads` worker threads and default queue settings.
    pub fn new(threads: usize) -> Self {
        Self::with_config(WorkerConfig::with_threads(threads))
    }

    /// Creates a pool from an explicit configuration.
    ///
    /// All `config.threads` workers are spawned immediately; each blocks on
    /// the shared condvar until a task is queued or shutdown is requested.
    pub fn with_config(config: WorkerConfig) -> Self {
        let shared_state = SharedState {
            queue: TaskQueue::new(config.queue_capacity),
            shutdown: false,
        };
        let state = Arc::new((Mutex::new(shared_state), Condvar::new()));
        let mut workers = Vec::with_capacity(config.threads);
        for id in 0..config.threads {
            workers.push(Worker::new(id, state.clone()));
        }
        Self {
            workers,
            state,
            config,
        }
    }

    /// Runs `f` on a dedicated blocking thread, bypassing this pool's queue
    /// and priority ordering entirely.
    pub fn spawn_blocking<F, T>(&self, f: F) -> WorkerHandle<T>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        WorkerHandle::spawn_blocking(f)
    }

    /// Queues `f` at [`Priority::Normal`].
    ///
    /// Returns `false` (dropping `f`) if the pool is shut down or the queue
    /// is full.
    pub fn submit<F>(&self, f: F) -> bool
    where
        F: FnOnce() + Send + 'static,
    {
        self.submit_with_priority(f, Priority::Normal)
    }

    /// Queues `f` with the given priority.
    ///
    /// Returns `true` and wakes one idle worker on success; returns `false`
    /// (dropping `f`) if the pool is shut down or the queue is at capacity.
    pub fn submit_with_priority<F>(&self, f: F, priority: Priority) -> bool
    where
        F: FnOnce() + Send + 'static,
    {
        let (lock, cvar) = &*self.state;
        let mut state = lock_util::lock_or_recover(lock);
        if state.shutdown {
            return false;
        }
        let seq = state.queue.next_seq;
        let task = QueuedTask::new(f, priority, seq);
        if state.queue.push(task) {
            // Consume the sequence number only on a successful enqueue so a
            // rejected (queue-full) submission leaves shared state untouched.
            state.queue.next_seq = seq.wrapping_add(1);
            cvar.notify_one();
            return true;
        }
        false
    }

    /// Number of workers whose threads are still running.
    pub fn active_workers(&self) -> usize {
        self.workers.iter().filter(|w| w.is_active()).count()
    }

    /// Number of tasks currently waiting in the queue.
    pub fn queue_len(&self) -> usize {
        let (lock, _) = &*self.state;
        lock_util::lock_or_recover(lock).queue.len()
    }

    /// Configured number of worker threads.
    pub fn thread_count(&self) -> usize {
        self.config.threads
    }

    /// Signals all workers to exit after draining the queue. Does not block;
    /// `Drop` performs the actual joins.
    pub fn shutdown(&self) {
        let (lock, cvar) = &*self.state;
        let mut state = lock_util::lock_or_recover(lock);
        state.shutdown = true;
        cvar.notify_all();
    }

    /// Whether `shutdown` has been requested.
    pub fn is_shutdown(&self) -> bool {
        let (lock, _) = &*self.state;
        lock_util::lock_or_recover(lock).shutdown
    }
}
impl Default for WorkerPool {
fn default() -> Self {
Self::with_config(WorkerConfig::default())
}
}
impl Drop for WorkerPool {
fn drop(&mut self) {
self.shutdown();
for worker in &mut self.workers {
worker.join();
}
}
}
/// Handle to a single pool thread.
pub struct Worker {
// Index assigned by the pool; also embedded in the thread's name.
id: usize,
// `None` after `join`, or if the spawn attempt failed.
thread: Option<JoinHandle<()>>,
// Flipped to `false` by the worker thread just before it exits.
active: Arc<Mutex<bool>>,
}
impl Worker {
    /// Spawns a thread named `revue-worker-{id}` that drains the shared
    /// queue until shutdown is signalled and the queue is empty.
    fn new(id: usize, state: Arc<(Mutex<SharedState>, Condvar)>) -> Self {
        let active = Arc::new(Mutex::new(true));
        let active_clone = active.clone();
        let thread = thread::Builder::new()
            .name(format!("revue-worker-{}", id))
            .spawn(move || {
                let (lock, cvar) = &*state;
                loop {
                    // Hold the lock only long enough to dequeue; the task
                    // itself runs without the lock.
                    let task = {
                        let mut state = lock_util::lock_or_recover(lock);
                        while state.queue.is_empty() && !state.shutdown {
                            state = cvar.wait(state).unwrap_or_else(|poisoned| {
                                log_warn!("Condvar wait was poisoned, recovering");
                                poisoned.into_inner()
                            });
                        }
                        // Drain remaining tasks before honoring shutdown.
                        if state.shutdown && state.queue.is_empty() {
                            break;
                        }
                        state.queue.pop()
                    };
                    if let Some(queued_task) = task {
                        (queued_task.task)();
                    }
                }
                if let Ok(mut active) = active_clone.lock() {
                    *active = false;
                }
            })
            // NOTE: the spawn error itself is intentionally best-effort and
            // discarded; the flag below records the failure.
            .ok();
        // If the OS refused to spawn the thread, this worker will never run:
        // mark it inactive so `is_active`/`active_workers` stay accurate
        // (previously the flag stayed `true` forever in this case).
        if thread.is_none() {
            *lock_util::lock_or_recover(&active) = false;
        }
        Self { id, thread, active }
    }

    /// Identifier assigned at construction.
    pub fn id(&self) -> usize {
        self.id
    }

    /// `true` while the worker thread is still running its loop.
    pub fn is_active(&self) -> bool {
        *lock_util::lock_or_recover(&self.active)
    }

    /// Blocks until the worker thread exits; idempotent.
    pub fn join(&mut self) {
        if let Some(thread) = self.thread.take() {
            let _ = thread.join();
        }
    }
}
/// Bounded max-heap of pending tasks, ordered by `QueuedTask`'s `Ord`.
struct TaskQueue {
    // Heap ordered by priority, then FIFO within equal priority.
    tasks: BinaryHeap<QueuedTask>,
    // Hard cap on queued tasks; `push` rejects beyond this.
    capacity: usize,
    // Next submission sequence number; advanced by the pool on enqueue so
    // equal-priority tasks keep submission order.
    next_seq: u64,
}

impl TaskQueue {
    /// Creates an empty queue accepting at most `capacity` tasks.
    fn new(capacity: usize) -> Self {
        Self {
            tasks: BinaryHeap::with_capacity(capacity),
            capacity,
            next_seq: 0,
        }
    }

    /// Number of tasks currently queued.
    fn len(&self) -> usize {
        self.tasks.len()
    }

    /// Whether the queue holds no tasks.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Enqueues `task`; returns `false` (dropping nothing — the caller still
    /// owns the rejection decision) when the queue is at capacity.
    fn push(&mut self, task: QueuedTask) -> bool {
        if self.tasks.len() < self.capacity {
            self.tasks.push(task);
            true
        } else {
            false
        }
    }

    /// Removes and returns the highest-priority task, if any.
    fn pop(&mut self) -> Option<QueuedTask> {
        self.tasks.pop()
    }
}
/// A closure waiting in the queue, tagged with ordering metadata.
struct QueuedTask {
// The work itself; boxed so tasks of different closure types share a queue.
task: Box<dyn FnOnce() + Send + 'static>,
// Scheduling priority; higher pops first.
priority: Priority,
// Monotonic submission counter used to break priority ties FIFO.
seq: u64,
}
impl QueuedTask {
/// Wraps `task` with its priority and submission sequence number.
fn new<F>(task: F, priority: Priority, seq: u64) -> Self
where
F: FnOnce() + Send + 'static,
{
Self {
task: Box::new(task),
priority,
seq,
}
}
}
impl PartialEq for QueuedTask {
    /// Tasks compare equal when both sequence number and priority match;
    /// the boxed closure is deliberately not part of the comparison.
    fn eq(&self, other: &Self) -> bool {
        self.seq == other.seq && self.priority == other.priority
    }
}
impl Eq for QueuedTask {}
impl PartialOrd for QueuedTask {
// A total order exists, so delegate to `Ord` (the canonical form).
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for QueuedTask {
    /// Orders by priority first; among equal priorities the *lower* sequence
    /// number compares greater, so the max-heap pops tasks FIFO.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.priority
            .cmp(&other.priority)
            .then_with(|| other.seq.cmp(&self.seq))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};

    /// Polls `cond` every millisecond until it holds or `timeout` elapses;
    /// returns whether the condition was observed. Replaces the fixed sleeps
    /// the tests previously used, so they finish as soon as the pool catches
    /// up instead of always waiting the worst case (and failing if the pool
    /// is slower than the hard-coded sleep).
    fn wait_until(timeout: Duration, cond: impl Fn() -> bool) -> bool {
        let deadline = Instant::now() + timeout;
        loop {
            if cond() {
                return true;
            }
            if Instant::now() >= deadline {
                return false;
            }
            thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn test_worker_pool_new() {
        let pool = WorkerPool::new(4);
        assert_eq!(pool.thread_count(), 4);
    }

    #[test]
    fn test_worker_pool_default() {
        let pool = WorkerPool::default();
        assert!(pool.thread_count() >= 1);
    }

    #[test]
    fn test_worker_pool_shutdown() {
        let pool = WorkerPool::new(2);
        assert!(!pool.is_shutdown());
        pool.shutdown();
        assert!(pool.is_shutdown());
    }

    #[test]
    fn test_worker_pool_submit() {
        let pool = WorkerPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));
        for _ in 0..10 {
            let counter = counter.clone();
            assert!(pool.submit(move || {
                counter.fetch_add(1, Ordering::SeqCst);
            }));
        }
        // Bounded wait instead of a blind 100 ms sleep.
        assert!(wait_until(Duration::from_secs(2), || {
            counter.load(Ordering::SeqCst) == 10
        }));
        pool.shutdown();
    }

    #[test]
    fn test_worker_pool_submit_after_shutdown() {
        let pool = WorkerPool::new(1);
        pool.shutdown();
        assert!(!pool.submit(|| {}));
    }

    #[test]
    fn test_worker_pool_priority() {
        use std::sync::atomic::AtomicBool;
        let pool = WorkerPool::new(1);
        let order = Arc::new(Mutex::new(Vec::new()));
        let barrier = Arc::new(AtomicBool::new(false));
        let barrier_clone = barrier.clone();
        // Occupy the single worker so the next submissions stay queued and
        // get reordered by priority before anything runs.
        pool.submit(move || {
            while !barrier_clone.load(Ordering::SeqCst) {
                thread::sleep(Duration::from_millis(1));
            }
        });
        // Wait until the worker has actually dequeued the blocker rather
        // than sleeping a fixed 10 ms and hoping it has.
        assert!(wait_until(Duration::from_secs(2), || pool.queue_len() == 0));
        let order1 = order.clone();
        pool.submit_with_priority(
            move || {
                lock_util::lock_or_recover(&order1).push("low");
            },
            Priority::Low,
        );
        let order2 = order.clone();
        pool.submit_with_priority(
            move || {
                lock_util::lock_or_recover(&order2).push("high");
            },
            Priority::High,
        );
        barrier.store(true, Ordering::SeqCst);
        let order_probe = order.clone();
        assert!(wait_until(Duration::from_secs(2), move || {
            lock_util::lock_or_recover(&order_probe).len() == 2
        }));
        pool.shutdown();
        let result = lock_util::lock_or_recover(&order);
        assert_eq!(result.len(), 2);
        // High priority must run first even though it was submitted second.
        assert_eq!(result[0], "high");
        assert_eq!(result[1], "low");
    }
}