backup_deduplicator/pool.rs
1use std::sync::{Arc, mpsc, Mutex};
2use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender};
3use std::thread;
4use std::time::Duration;
5use log::{debug, error, trace, warn};
6
/// Contract for the unit of work that the thread pool hands to its workers.
///
/// Every job has to expose a numeric identifier; the pool reads it when a worker
/// picks the job up (for trace logging).
pub trait JobTrait<T = Self>
where
    T: Send,
{
    /// Identifier of this job.
    ///
    /// # Returns
    /// * `usize` - The id reported for this job.
    fn job_id(&self) -> usize;
}
15
/// Marker contract for the values that worker threads send back to the owner of
/// the thread pool. It declares no methods; implementing it opts a type in.
pub trait ResultTrait<T = Self>
where
    T: Send,
{
}
18
/// Signature of the user-supplied function a worker thread runs for every job it
/// takes from the pool.
///
/// The function returns nothing; whatever a job produces has to be sent through
/// the result sender so that it reaches the owner of the pool.
///
/// # Arguments
/// * `usize` - Id of the worker thread that is running the job.
/// * `Job` - The job taken from the pool's queue.
/// * `&Sender<Result>` - Channel on which job results are delivered.
/// * `&Sender<Job>` - Channel through which follow-up jobs are fed back into the pool.
/// * `&mut Argument` - The worker's own argument, as handed to the pool on creation.
type WorkerEntry<Job, Result, Argument> =
    fn(usize, Job, &Sender<Result>, &Sender<Job>, &mut Argument);
33
/// Handle the thread pool keeps for one of its worker threads.
///
/// # Fields
/// * `id` - Identifier passed to the worker entry function for every job this thread runs.
/// * `thread` - Join handle of the spawned thread; `None` once it has been taken for joining.
struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}
44
45impl Worker {
46 /// Create a new worker thread. Starts the worker thread and returns the worker struct.
47 ///
48 /// # Arguments
49 /// * `id` - The worker id.
50 /// * `job_receive` - A receiver to receive jobs from the thread pool.
51 /// * `result_publish` - A sender to publish job results.
52 /// * `job_publish` - A sender to publish new jobs to the thread pool.
53 /// * `func` - The worker entry function to process jobs.
54 /// * `arg` - The arguments passed to the worker thread via the thread pool creation.
55 ///
56 /// # Returns
57 /// * `Worker` - The worker struct with the worker thread handle.
58 fn new<Job: JobTrait + Send + 'static, Result: ResultTrait + Send + 'static, Argument: Send + 'static>(id: usize, job_receive: Arc<Mutex<Receiver<Job>>>, result_publish: Sender<Result>, job_publish: Sender<Job>, func: WorkerEntry<Job, Result, Argument>, arg: Argument) -> Worker {
59 let thread = thread::spawn(move || {
60 Worker::worker_entry(id, job_receive, result_publish, job_publish, func, arg);
61 });
62
63 Worker { id, thread: Some(thread) }
64 }
65
66 /// Function executed by the worker thread. Does exit when the job receiver is closed/the thread pool is shutting down.
67 ///
68 /// # Arguments
69 /// * `id` - The worker id.
70 /// * `job_receive` - A receiver to receive jobs from the thread pool.
71 /// * `result_publish` - A sender to publish job results.
72 /// * `job_publish` - A sender to publish new jobs to the thread pool.
73 /// * `func` - The worker entry function to process jobs.
74 /// * `arg` - The arguments passed to the worker thread via the thread pool creation.
75 fn worker_entry<Job: JobTrait + Send + 'static, Result: ResultTrait + Send + 'static, Argument: Send + 'static>(id: usize, job_receive: Arc<Mutex<Receiver<Job>>>, result_publish: Sender<Result>, job_publish: Sender<Job>, func: WorkerEntry<Job, Result, Argument>, mut arg: Argument) {
76 loop {
77 // Acquire the job lock
78 let job = job_receive.lock();
79
80 let job = match job {
81 Err(e) => {
82 error!("Worker {} shutting down {}", id, e);
83 break;
84 }
85 Ok(job) => {
86 job.recv() // receive new job
87 }
88 };
89
90 match job {
91 Err(_) => {
92 trace!("Worker {} shutting down", id);
93 break;
94 }
95 Ok(job) => {
96 trace!("Worker {} received job {}", id, job.job_id());
97 // Call the user function to process the job
98 func(id, job, &result_publish, &job_publish, &mut arg);
99 }
100 }
101 }
102 }
103}
104
105/// A thread pool to manage the distribution of jobs to worker threads.
106///
107/// # Template Parameters
108/// * `Job` - The job type that should be processed by the worker threads.
109/// * `Result` - The result type that should be returned by the worker threads.
110///
111/// Both `Job` and `Result` must implement the `Send` trait.
112pub struct ThreadPool<Job, Result>
113where
114 Job: Send,
115 Result: Send,
116{
117 workers: Vec<Worker>,
118 thread: Option<thread::JoinHandle<()>>,
119 job_publish: Arc<Mutex<Option<Sender<Job>>>>,
120 result_receive: Receiver<Result>,
121}
122
123impl<Job: Send + JobTrait + 'static, Result: Send + ResultTrait + 'static> ThreadPool<Job, Result> {
124 /// Create a new thread pool with a given number of worker threads (args.len()).
125 /// Each worker thread will receive an argument from the args vector. When a new job
126 /// is published to the thread pool, the thread pool will distribute the job to the worker threads
127 /// and execute the `func` function within a worker thread.
128 ///
129 /// # Arguments
130 /// * `args` - A vector of arguments that should be passed to the worker threads.
131 /// * `func` - The worker entry function to process jobs.
132 ///
133 /// # Returns
134 /// * `ThreadPool` - The thread pool struct with the worker threads.
135 ///
136 /// # Template Parameters
137 /// * `Argument` - The argument type that should be passed to the worker threads.
138 /// The argument type must implement the `Send` trait.
139 pub fn new<Argument: Send + 'static>(mut args: Vec<Argument>, func: WorkerEntry<Job, Result, Argument>) -> ThreadPool<Job, Result> {
140 assert!(args.len() > 0);
141
142 let mut workers = Vec::with_capacity(args.len());
143
144 let (job_publish, job_receive) = mpsc::channel();
145
146 let job_receive = Arc::new(Mutex::new(job_receive));
147 let (result_publish, result_receive) = mpsc::channel();
148 let (thread_publish_job, thread_receive_job) = mpsc::channel();
149
150 let mut id = 0;
151 while let Some(arg) = args.pop() {
152 workers.push(Worker::new(id, Arc::clone(&job_receive), result_publish.clone(), thread_publish_job.clone(), func, arg));
153 id += 1;
154 }
155
156 let job_publish = Arc::new(Mutex::new(Some(job_publish)));
157 let job_publish_clone = Arc::clone(&job_publish);
158
159 let thread = thread::spawn(move || {
160 ThreadPool::<Job, Result>::pool_entry(job_publish_clone, thread_receive_job);
161 });
162
163 ThreadPool {
164 workers,
165 job_publish,
166 result_receive,
167 thread: Some(thread),
168 }
169 }
170
171 /// Publish a new job to the thread pool. The job will be distributed to a worker thread.
172 ///
173 /// # Arguments
174 /// * `job` - The job that should be processed by a worker thread.
175 pub fn publish(&self, job: Job) {
176 let job_publish = self.job_publish.lock();
177 match job_publish {
178 Err(e) => {
179 error!("ThreadPool is shutting down. Cannot publish job. {}", e);
180 }
181 Ok(job_publish) => {
182 match job_publish.as_ref() {
183 None => {
184 error!("ThreadPool is shutting down. Cannot publish job.");
185 }
186 Some(job_publish) => {
187 match job_publish.send(job) {
188 Err(e) => {
189 error!("Failed to publish job on thread pool. {}", e);
190 }
191 Ok(_) => {}
192 }
193 }
194 }
195 }
196 }
197 }
198
199 /// Internal function that is run in a separate thread. It feeds back jobs from the worker threads to the input of the thread pool.
200 ///
201 /// # Arguments
202 /// * `job_publish` - A sender to publish new jobs to the thread pool.
203 /// * `job_receive` - A receiver to receive jobs from the worker threads.
204 fn pool_entry(job_publish: Arc<Mutex<Option<Sender<Job>>>>, job_receive: Receiver<Job>) {
205 loop {
206 let job = job_receive.recv();
207
208 match job {
209 Err(_) => {
210 trace!("Pool worker shutting down");
211 break;
212 }
213 Ok(job) => {
214 match job_publish.lock() {
215 Err(e) => {
216 error!("Pool worker shutting down: {}", e);
217 break;
218 }
219 Ok(job_publish) => {
220 if let Some(job_publish) = job_publish.as_ref() {
221 job_publish.send(job).expect("Pool worker failed to send job. This should never fail.");
222 }
223 }
224 }
225 }
226 }
227 }
228 }
229
230 /// Receive a result from the worker threads. This function will block until a result is available.
231 ///
232 /// # Returns
233 /// * `Result` - The result of a job processed by a worker thread.
234 ///
235 /// # Errors
236 /// * If all worker threads panicked, therefore the pipe is closed
237 pub fn receive(&self) -> std::result::Result<Result, mpsc::RecvError> {
238 self.result_receive.recv()
239 }
240
241 /// Receive a result from the worker threads. This function will block until a result is available or a timeout occurs.
242 ///
243 /// # Arguments
244 /// * `timeout` - The maximum time to wait for a result.
245 ///
246 /// # Returns
247 /// * `Result` - The result of a job processed by a worker thread.
248 ///
249 /// # Errors
250 /// * If all worker threads panicked, therefore the pipe is closed
251 /// * If the timeout occurs before a result is available
252 pub fn receive_timeout(&self, timeout: Duration) -> std::result::Result<Result, RecvTimeoutError> {
253 self.result_receive.recv_timeout(timeout)
254 }
255}
256
257impl<Job: Send, Result: Send> Drop for ThreadPool<Job, Result> {
258 fn drop(&mut self) {
259 drop(self.job_publish.lock().expect("This should not break").take());
260
261 for worker in &mut self.workers {
262 debug!("Shutting down worker {}", worker.id);
263
264 if let Some(thread) = worker.thread.take() {
265 match thread.join() {
266 Ok(_) => {
267 trace!("Worker {} shut down", worker.id);
268 }
269 Err(_) => {
270 warn!("Worker {} panicked", worker.id);
271 }
272 }
273 }
274 }
275
276 if let Some(thread) = self.thread.take() {
277 match thread.join() {
278 Ok(_) => {
279 trace!("ThreadPool shut down");
280 }
281 Err(_) => {
282 warn!("ThreadPool worker panicked");
283 }
284 }
285 }
286 }
287}