use crate::constants::MAX_TASK_QUEUE_SIZE;
use crate::utils::lock::lock_or_recover;
use std::collections::{HashMap, VecDeque};
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
/// Caller-chosen identifier for a task.
///
/// The runner uses it to reject duplicate submissions while a task with the
/// same id is still pending, and to label each result returned by `poll`.
pub type TaskId = String;
/// Outcome of one finished task, as handed back by [`PooledTaskRunner::poll`].
#[derive(Debug)]
pub struct TaskResult<T> {
/// The id the task was submitted under.
pub id: TaskId,
/// `Ok` with the task's return value, or `Err` with a textual description of
/// the failure (for example, a panic caught by the worker thread).
pub result: Result<T, String>,
}
/// A unit of work sent from the runner to the worker threads over the work channel.
struct WorkItem<T> {
/// Id echoed back in the matching `ResultMessage`.
id: TaskId,
/// The closure to run; called at most once by whichever worker receives it.
task: Box<dyn FnOnce() -> T + Send + 'static>,
}
/// Message a worker sends back to the runner after running a `WorkItem`.
struct ResultMessage<T> {
/// Id of the `WorkItem` this result belongs to.
id: TaskId,
/// The task's return value, or an error string if the task panicked.
result: Result<T, String>,
}
/// Owner of one worker thread's handle.
///
/// NOTE(review): nothing visible here ever joins `_handle`, so the thread is
/// effectively detached; it only ends when its receive loop exits.
struct Worker {
_handle: JoinHandle<()>,
}
impl Worker {
fn new<T: Send + 'static>(
work_rx: Arc<Mutex<Receiver<WorkItem<T>>>>,
result_tx: Sender<ResultMessage<T>>,
) -> Self {
let handle = thread::spawn(move || {
loop {
let work_item = {
let rx = lock_or_recover(&work_rx);
rx.recv()
};
match work_item {
Ok(item) => {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
(item.task)()
}));
let msg = match result {
Ok(value) => ResultMessage {
id: item.id,
result: Ok(value),
},
Err(e) => ResultMessage {
id: item.id,
result: Err(format!("Task panicked: {:?}", e)),
},
};
let _ = result_tx.send(msg);
}
Err(_) => {
break;
}
}
}
});
Worker { _handle: handle }
}
}
/// A fixed-size pool of worker threads that runs submitted closures and hands
/// their results back through the non-blocking [`PooledTaskRunner::poll`].
///
/// The number of workers is fixed when the runner is created. Each task is
/// identified by a caller-chosen [`TaskId`].
pub struct PooledTaskRunner<T: Send + 'static> {
/// Sending side of the work channel, created with capacity
/// `MAX_TASK_QUEUE_SIZE`; workers receive from it.
work_tx: mpsc::SyncSender<WorkItem<T>>,
/// The shared receiver; every worker holds a clone of this same `Arc`.
_work_rx: Arc<Mutex<Receiver<WorkItem<T>>>>,
/// Receives results from all workers; drained by `poll` with `try_recv`.
result_rx: Receiver<ResultMessage<T>>,
/// A result sender kept by the runner in addition to the workers' clones
/// (presumably so `result_rx` stays connected even if every worker exits).
_result_tx: Sender<ResultMessage<T>>,
/// The worker threads, created once in `new`.
_workers: Vec<Worker>,
/// Ids of tasks accepted by `spawn` whose result has not yet been returned by
/// `poll`; used as a set (the value is always `()`).
pending: HashMap<TaskId, ()>,
/// Double-ended queue of `(id, task)` pairs, the same payload as a
/// `WorkItem`; created empty in `new`.
_queue: VecDeque<(TaskId, Box<dyn FnOnce() -> T + Send + 'static>)>,
}
impl<T: Send + 'static> PooledTaskRunner<T> {
pub fn new(num_workers: usize) -> Self {
assert!(num_workers > 0, "Must have at least 1 worker");
let (work_tx, work_rx) = mpsc::sync_channel(MAX_TASK_QUEUE_SIZE);
let (result_tx, result_rx) = mpsc::channel();
let work_rx = Arc::new(Mutex::new(work_rx));
let mut workers = Vec::with_capacity(num_workers);
for _ in 0..num_workers {
workers.push(Worker::new(work_rx.clone(), result_tx.clone()));
}
Self {
work_tx,
_work_rx: work_rx,
result_rx,
_result_tx: result_tx,
_workers: workers,
pending: HashMap::new(),
_queue: VecDeque::new(),
}
}
pub fn spawn<F>(&mut self, id: impl Into<TaskId>, task: F)
where
F: FnOnce() -> T + Send + 'static,
{
let id = id.into();
if self.pending.contains_key(&id) {
return; }
self.pending.insert(id.clone(), ());
let work_item = WorkItem {
id,
task: Box::new(task),
};
let _ = self.work_tx.try_send(work_item);
}
pub fn spawn_result<F, E>(&mut self, id: impl Into<TaskId>, task: F)
where
F: FnOnce() -> Result<T, E> + Send + 'static,
E: std::fmt::Display,
{
self.spawn(id, move || match task() {
Ok(value) => value,
Err(e) => panic!("Task error: {}", e),
});
}
pub fn poll(&mut self) -> Option<TaskResult<T>> {
match self.result_rx.try_recv() {
Ok(msg) => {
self.pending.remove(&msg.id);
Some(TaskResult {
id: msg.id,
result: msg.result,
})
}
Err(_) => None,
}
}
pub fn has_pending(&self) -> bool {
!self.pending.is_empty()
}
pub fn pending_count(&self) -> usize {
self.pending.len()
}
pub fn is_pending(&self, id: &str) -> bool {
self.pending.contains_key(id)
}
}
/// Dropping the runner does not join the worker threads.
///
/// After this (empty) body runs, the struct's fields are dropped, including
/// `work_tx`, the only sender for the work channel. Each worker then leaves
/// its loop once its `recv` returns `Err`. NOTE(review): items still buffered
/// in the channel are presumably delivered before that error, so they may
/// still run. A task that is already running finishes in the background, and
/// its result is discarded because the worker ignores the failed `send`.
///
/// NOTE(review): the explicit `Drop` impl also forbids moving fields out of a
/// `PooledTaskRunner` by destructuring; it cannot simply join the workers
/// here, because `work_tx` is still alive while `drop` runs.
impl<T: Send + 'static> Drop for PooledTaskRunner<T> {
fn drop(&mut self) {
// Intentionally empty: the worker threads are left detached (see above).
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;
    use std::time::{Duration, Instant};

    /// Longest any test waits for results, so a lost task fails the test
    /// instead of hanging the whole suite.
    const RESULT_TIMEOUT: Duration = Duration::from_secs(10);

    /// Polls `runner` until `count` results have arrived and returns them in
    /// arrival order.
    ///
    /// # Panics
    ///
    /// Panics if fewer than `count` results arrive within [`RESULT_TIMEOUT`].
    fn wait_for<T: Send + 'static>(
        runner: &mut PooledTaskRunner<T>,
        count: usize,
    ) -> Vec<TaskResult<T>> {
        let deadline = Instant::now() + RESULT_TIMEOUT;
        let mut results = Vec::with_capacity(count);
        while results.len() < count {
            if let Some(result) = runner.poll() {
                results.push(result);
                continue;
            }
            assert!(
                Instant::now() < deadline,
                "timed out with {} of {} results",
                results.len(),
                count
            );
            thread::sleep(Duration::from_millis(1));
        }
        results
    }

    #[test]
    fn test_pooled_runner_basic() {
        let mut runner = PooledTaskRunner::new(2);
        runner.spawn("task1", || 42);
        runner.spawn("task2", || 100);
        assert!(runner.has_pending());
        let mut values: Vec<i32> = wait_for(&mut runner, 2)
            .into_iter()
            .map(|r| r.result.expect("task should succeed"))
            .collect();
        values.sort_unstable();
        assert_eq!(values, vec![42, 100]);
        assert!(!runner.has_pending());
    }

    #[test]
    fn test_pooled_runner_many_tasks() {
        let mut runner = PooledTaskRunner::new(4);
        let counter = Arc::new(AtomicUsize::new(0));
        for i in 0..100 {
            let counter = Arc::clone(&counter);
            runner.spawn(format!("task_{}", i), move || {
                counter.fetch_add(1, Ordering::SeqCst);
                i
            });
        }
        let mut values: Vec<i32> = wait_for(&mut runner, 100)
            .into_iter()
            .map(|r| r.result.expect("task should succeed"))
            .collect();
        values.sort_unstable();
        assert_eq!(values, (0..100).collect::<Vec<i32>>());
        assert_eq!(counter.load(Ordering::SeqCst), 100);
        assert!(!runner.has_pending());
    }

    #[test]
    fn test_pooled_runner_duplicate_id() {
        let mut runner = PooledTaskRunner::new(2);
        runner.spawn("duplicate", || 1);
        // Ignored: the first "duplicate" stays pending until `poll` returns it.
        runner.spawn("duplicate", || 2);
        assert_eq!(runner.pending_count(), 1);
        let results = wait_for(&mut runner, 1);
        assert_eq!(results[0].id, "duplicate");
        assert_eq!(results[0].result, Ok(1));
        assert!(!runner.has_pending());
        assert!(runner.poll().is_none());
    }

    #[test]
    fn test_pooled_runner_panic_handling() {
        let mut runner = PooledTaskRunner::<i32>::new(2);
        runner.spawn("panic_task", || {
            panic!("Test panic");
        });
        let result = wait_for(&mut runner, 1).remove(0);
        assert_eq!(result.id, "panic_task");
        let err = result
            .result
            .expect_err("a panicking task should report Err");
        assert!(err.contains("panicked"));
    }

    #[test]
    fn test_pooled_runner_worker_survives_panic() {
        // With a single worker, the follow-up task can only run if the panic
        // was caught rather than killing the thread.
        let mut runner = PooledTaskRunner::<i32>::new(1);
        runner.spawn("panic_task", || {
            panic!("Test panic");
        });
        runner.spawn("after_panic", || 7);
        let results = wait_for(&mut runner, 2);
        let after = results
            .iter()
            .find(|r| r.id == "after_panic")
            .expect("follow-up task should report a result");
        assert_eq!(after.result, Ok(7));
    }

    #[test]
    fn test_pooled_runner_bounded_concurrency() {
        let runner = PooledTaskRunner::<()>::new(2);
        assert_eq!(runner._workers.len(), 2);
    }

    #[test]
    fn test_pooled_runner_pending_count() {
        let mut runner = PooledTaskRunner::new(2);
        assert_eq!(runner.pending_count(), 0);
        runner.spawn("task1", || 42);
        runner.spawn("task2", || 100);
        runner.spawn("task3", || 200);
        assert_eq!(runner.pending_count(), 3);
        wait_for(&mut runner, 1);
        assert!(runner.pending_count() < 3);
        wait_for(&mut runner, 2);
        assert_eq!(runner.pending_count(), 0);
    }

    #[test]
    fn test_pooled_runner_is_pending() {
        let mut runner = PooledTaskRunner::new(2);
        assert!(!runner.is_pending("task1"));
        runner.spawn("task1", || 42);
        assert!(runner.is_pending("task1"));
        assert!(!runner.is_pending("task2"));
        wait_for(&mut runner, 1);
        assert!(!runner.is_pending("task1"));
    }

    #[test]
    fn test_pooled_runner_spawn_result_ok() {
        let mut runner = PooledTaskRunner::new(2);
        runner.spawn_result("task1", || Ok::<i32, &str>(42));
        let result = wait_for(&mut runner, 1).remove(0);
        assert_eq!(result.id, "task1");
        assert_eq!(result.result, Ok(42));
    }

    #[test]
    fn test_pooled_runner_spawn_result_err() {
        let mut runner = PooledTaskRunner::new(2);
        runner.spawn_result("task1", || Err::<i32, &str>("error"));
        let result = wait_for(&mut runner, 1).remove(0);
        assert_eq!(result.id, "task1");
        assert!(result.result.is_err());
    }

    #[test]
    #[should_panic(expected = "Must have at least 1 worker")]
    fn test_pooled_runner_no_workers_panics() {
        let _runner = PooledTaskRunner::<i32>::new(0);
    }

    #[test]
    fn test_pooled_runner_poll_empty() {
        let mut runner = PooledTaskRunner::<i32>::new(2);
        assert!(runner.poll().is_none());
    }

    #[test]
    fn test_task_result_fields() {
        let result = TaskResult {
            id: "test".to_string(),
            result: Ok(42),
        };
        assert_eq!(result.id, "test");
        assert_eq!(result.result.unwrap(), 42);
    }
}