// parallel_worker/workers/cancelable_worker.rs

1use std::sync::mpsc::Sender;
2
3use crate::{
4    internal::{TaskQueue, Work}, prelude::State, worker_traits::{WorkerInit, WorkerMethods}
5};
6
7use super::BasicWorker;
8
9/// A worker that processes tasks in parallel using multiple worker threads.
10/// Allows for optional results and task cancelation.
11pub struct CancelableWorker<T, R>
12where
13    T: Send + 'static,
14    R: Send + 'static,
15{
16    inner: BasicWorker<T, Option<R>>,
17    worker_state: Vec<State>,
18}
19
20impl<T, R> WorkerMethods<T, R> for CancelableWorker<T, R>
21where
22    T: Send + 'static,
23    R: Send + 'static,
24{
25    fn add_task(&self, task: T) {
26        self.inner.add_task(task);
27    }
28
29    fn add_tasks(&self, tasks: impl IntoIterator<Item = T>) {
30        self.inner.add_tasks(tasks);
31    }
32
33    /// Clear the task queue and cancel all ongoing tasks as soon as possible.
34    /// The results of canceled tasks will be discarded. Results of already completed tasks will remain unaffected.
35    /// Canceling tasks during their execution requires the worker function to use the [`crate::check_if_cancelled!`] macro.
36    fn cancel_tasks(&self) {
37        self.inner.cancel_tasks();
38        for state in &self.worker_state {
39            state.cancel();
40        }
41    }
42
43    fn get(&self) -> Option<R> {
44        self.inner.get_iter().flatten().next()
45    }
46
47    fn get_blocking(&self) -> Option<R> {
48        self.inner.get_iter_blocking().flatten().next()
49    }
50
51    fn pending_tasks(&self) -> usize {
52        self.inner.pending_tasks()
53    }
54}
55
56impl<T, R, F> WorkerInit<T, R, F> for CancelableWorker<T, R>
57where
58    T: Send + 'static,
59    R: Send + 'static,
60    F: Fn(T, &State) -> Option<R> + Copy + Send + 'static,
61{
62    fn with_num_threads(num_worker_threads: usize, worker_function: F) -> Self {
63        let (result_sender, result_receiver) = std::sync::mpsc::channel();
64        let task_queue = TaskQueue::new();
65
66        let mut worker_state = Vec::with_capacity(num_worker_threads);
67        for _ in 0..num_worker_threads {
68            let state = State::new();
69            spawn_worker_thread(
70                worker_function,
71                result_sender.clone(),
72                task_queue.clone(),
73                state.clone(),
74            );
75            worker_state.push(state);
76        }
77
78        CancelableWorker {
79            worker_state,
80            inner: BasicWorker::constructor(task_queue, result_receiver, num_worker_threads),
81        }
82    }
83}
84
85fn spawn_worker_thread<T, R, F>(
86    worker_function: F,
87    result_sender: Sender<Option<R>>,
88    task_queue: TaskQueue<Work<T>>,
89    state: State,
90) where
91    T: Send + 'static,
92    R: Send + 'static,
93    F: Fn(T, &State) -> Option<R> + Send + 'static,
94{
95    std::thread::spawn(move || {
96        loop {
97            match task_queue.wait_for_task_and_then(|| state.set_running()) {
98                Work::Terminate => break,
99                Work::Task(task) => {
100                    let result = worker_function(task, &state);
101                    let result = if state.is_cancelled() { None } else { result };
102
103                    if let Err(_) = result_sender.send(result) {
104                        break;
105                    }
106                }
107            }
108        }
109    });
110}
111
112impl<T, R> Drop for CancelableWorker<T, R>
113where
114    T: Send + 'static,
115    R: Send + 'static,
116{
117    /// Drop the worker and terminate all worker threads. Cancel all tasks as soon as possible.
118    fn drop(&mut self) {
119        for state in &self.worker_state {
120            state.cancel();
121        }
122    }
123}