use crossbeam_channel::{Receiver, Sender, unbounded};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
/// Runs a user-supplied processor over submitted inputs on rayon's thread pool.
///
/// Inputs travel to a background dispatcher thread over `task_tx`. Each
/// processed result comes back over `result_rx`. Two shared counters and two
/// shared flags let callers and the dispatcher agree on when all work is done.
pub struct Worker<I: Send + 'static, O: Send + 'static> {
    /// Receiving side of the channel every task's `Result` is sent on.
    result_rx: Receiver<Result<O, String>>,
    /// Sending side of the channel feeding inputs to the dispatcher thread.
    task_tx: Sender<I>,
    /// Number of inputs accepted by `submit`.
    pending_tasks: Arc<AtomicUsize>,
    /// Number of tasks whose result has been produced.
    completed_tasks: Arc<AtomicUsize>,
    /// Set to stop the dispatcher thread's loop.
    shutdown: Arc<AtomicBool>,
    /// Set once the caller has declared that no more inputs will be submitted.
    submitting_finished: Arc<AtomicBool>,
}
impl<I, O> Worker<I, O>
where
I: Send + 'static,
O: Send + 'static,
{
pub fn new<F>(processor: F) -> Self
where
F: Fn(I) -> Result<O, String> + Send + Sync + 'static,
{
let (task_tx, task_rx) = unbounded::<I>();
let (result_tx, result_rx) = unbounded::<Result<O, String>>();
let shutdown = Arc::new(AtomicBool::new(false));
let pending_tasks = Arc::new(AtomicUsize::new(0));
let completed_tasks = Arc::new(AtomicUsize::new(0));
let submitting_finished = Arc::new(AtomicBool::new(false));
let shutdown_clone = shutdown.clone();
let pending_clone = pending_tasks.clone();
let completed_clone = completed_tasks.clone();
let submitting_clone = submitting_finished.clone();
let processor = Arc::new(processor);
std::thread::spawn(move || {
while !shutdown_clone.load(Ordering::Relaxed) {
match task_rx.recv_timeout(Duration::from_millis(100)) {
Ok(input) => {
let tx = result_tx.clone();
let completed = completed_clone.clone();
let processor = processor.clone();
rayon::spawn(move || {
let result = processor(input);
let _ = tx.send(result);
completed.fetch_add(1, Ordering::Relaxed);
});
}
Err(_) => {
if submitting_clone.load(Ordering::Relaxed) {
let pending = pending_clone.load(Ordering::Relaxed);
let completed = completed_clone.load(Ordering::Relaxed);
if pending > 0 && pending == completed {
shutdown_clone.store(true, Ordering::Relaxed);
break;
}
}
continue;
}
}
}
});
Self {
result_rx,
task_tx,
pending_tasks,
completed_tasks,
shutdown,
submitting_finished,
}
}
pub fn submit(&self, input: I) -> Result<(), String> {
if self.submitting_finished.load(Ordering::Relaxed) {
return Err("Cannot submit after finish_submitting() was called".to_string());
}
self.task_tx
.send(input)
.map_err(|e| format!("Failed to submit task: {}", e))?;
self.pending_tasks.fetch_add(1, Ordering::Relaxed);
Ok(())
}
pub fn finish_submitting(&self) {
self.submitting_finished.store(true, Ordering::Relaxed);
}
pub fn poll_results(&self) -> Vec<Result<O, String>> {
let mut results = Vec::new();
while let Ok(result) = self.result_rx.try_recv() {
results.push(result);
}
results
}
pub fn is_complete(&self) -> bool {
if !self.submitting_finished.load(Ordering::Relaxed) {
return false;
}
let pending = self.pending_tasks.load(Ordering::Relaxed);
let completed = self.completed_tasks.load(Ordering::Relaxed);
pending > 0 && pending == completed
}
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::Relaxed);
}
}
impl<I: Send + 'static, O: Send + 'static> Drop for Worker<I, O> {
    /// Stops the dispatcher thread when the worker goes out of scope.
    fn drop(&mut self) {
        // The thread is not joined here. It observes the shutdown flag on its
        // next loop iteration and exits by itself.
        Self::shutdown(self);
    }
}