// backup_deduplicator/pool.rs

use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender};
use std::sync::{mpsc, Arc, Mutex, PoisonError};
use std::thread;
use std::time::Duration;

use log::{debug, error, trace, warn};

/// Interface every job type processed by a [`ThreadPool`] has to provide.
///
/// The type parameter `T` defaults to the implementing type and is only constrained to be
/// `Send`; the trait's single method does not use it.
pub trait JobTrait<T = Self>
where
    T: Send,
{
    /// Identifier of this job; the worker threads include it in their trace log output.
    ///
    /// # Returns
    /// * `usize` - The job id.
    fn job_id(&self) -> usize;
}

16/// A trait that must be implemented by a result type to be returned by the pool.
17pub trait ResultTrait<T: Send = Self> {}
18
/// Signature of the user-supplied function that processes jobs inside a worker thread.
///
/// A custom worker hands a function of this type to the thread pool; the pool calls it once
/// per job, on whichever worker thread received that job.
///
/// # Arguments
/// * `usize` - Id of the worker thread making the call.
/// * `Job` - The job to process; ownership is passed to the function.
/// * `&Sender<Result>` - Channel for sending job results back to the pool's owner.
/// * `&Sender<Job>` - Channel for publishing follow-up jobs into the same pool.
/// * `&mut Argument` - This worker's own argument (supplied at pool creation). The same value
///   is passed to every call made by that worker, so it can carry per-worker state.
///
/// # Returns
/// * `()` - Nothing; results must be sent through the `&Sender<Result>` argument instead.
type WorkerEntry<Job, Result, Argument> =
    fn(usize, Job, &Sender<Result>, &Sender<Job>, &mut Argument);

/// Bookkeeping the thread pool keeps for one of its worker threads.
///
/// # Fields
/// * `id` - Identifier of the worker, passed to the worker entry function and used in logs.
/// * `thread` - Join handle of the spawned thread; `None` once it has been taken for joining.
struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

45impl Worker {
46    /// Create a new worker thread. Starts the worker thread and returns the worker struct.
47    /// 
48    /// # Arguments
49    /// * `id` - The worker id.
50    /// * `job_receive` - A receiver to receive jobs from the thread pool.
51    /// * `result_publish` - A sender to publish job results.
52    /// * `job_publish` - A sender to publish new jobs to the thread pool.
53    /// * `func` - The worker entry function to process jobs.
54    /// * `arg` - The arguments passed to the worker thread via the thread pool creation.
55    /// 
56    /// # Returns
57    /// * `Worker` - The worker struct with the worker thread handle.
58    fn new<Job: JobTrait + Send + 'static, Result: ResultTrait + Send + 'static, Argument: Send + 'static>(id: usize, job_receive: Arc<Mutex<Receiver<Job>>>, result_publish: Sender<Result>, job_publish: Sender<Job>, func: WorkerEntry<Job, Result, Argument>, arg: Argument) -> Worker {
59        let thread = thread::spawn(move || {
60            Worker::worker_entry(id, job_receive, result_publish, job_publish, func, arg);
61        });
62
63        Worker { id, thread: Some(thread) }
64    }
65
66    /// Function executed by the worker thread. Does exit when the job receiver is closed/the thread pool is shutting down.
67    /// 
68    /// # Arguments
69    /// * `id` - The worker id.
70    /// * `job_receive` - A receiver to receive jobs from the thread pool.
71    /// * `result_publish` - A sender to publish job results.
72    /// * `job_publish` - A sender to publish new jobs to the thread pool.
73    /// * `func` - The worker entry function to process jobs.
74    /// * `arg` - The arguments passed to the worker thread via the thread pool creation.
75    fn worker_entry<Job: JobTrait + Send + 'static, Result: ResultTrait + Send + 'static, Argument: Send + 'static>(id: usize, job_receive: Arc<Mutex<Receiver<Job>>>, result_publish: Sender<Result>, job_publish: Sender<Job>, func: WorkerEntry<Job, Result, Argument>, mut arg: Argument) {
76        loop {
77            // Acquire the job lock
78            let job = job_receive.lock();
79
80            let job = match job {
81                Err(e) => {
82                    error!("Worker {} shutting down {}", id, e);
83                    break;
84                }
85                Ok(job) => {
86                    job.recv() // receive new job
87                }
88            };
89
90            match job {
91                Err(_) => {
92                    trace!("Worker {} shutting down", id);
93                    break;
94                }
95                Ok(job) => {
96                    trace!("Worker {} received job {}", id, job.job_id());
97                    // Call the user function to process the job
98                    func(id, job, &result_publish, &job_publish, &mut arg);
99                }
100            }
101        }
102    }
103}
104
105/// A thread pool to manage the distribution of jobs to worker threads.
106/// 
107/// # Template Parameters
108/// * `Job` - The job type that should be processed by the worker threads.
109/// * `Result` - The result type that should be returned by the worker threads.
110/// 
111/// Both `Job` and `Result` must implement the `Send` trait.
112pub struct ThreadPool<Job, Result>
113where
114    Job: Send,
115    Result: Send,
116{
117    workers: Vec<Worker>,
118    thread: Option<thread::JoinHandle<()>>,
119    job_publish: Arc<Mutex<Option<Sender<Job>>>>,
120    result_receive: Receiver<Result>,
121}
122
123impl<Job: Send + JobTrait + 'static, Result: Send + ResultTrait + 'static> ThreadPool<Job, Result> {
124    /// Create a new thread pool with a given number of worker threads (args.len()).
125    /// Each worker thread will receive an argument from the args vector. When a new job
126    /// is published to the thread pool, the thread pool will distribute the job to the worker threads
127    /// and execute the `func` function within a worker thread.
128    /// 
129    /// # Arguments
130    /// * `args` - A vector of arguments that should be passed to the worker threads.
131    /// * `func` - The worker entry function to process jobs.
132    /// 
133    /// # Returns
134    /// * `ThreadPool` - The thread pool struct with the worker threads.
135    /// 
136    /// # Template Parameters
137    /// * `Argument` - The argument type that should be passed to the worker threads.
138    /// The argument type must implement the `Send` trait.
139    pub fn new<Argument: Send + 'static>(mut args: Vec<Argument>, func: WorkerEntry<Job, Result, Argument>) -> ThreadPool<Job, Result> {
140        assert!(args.len() > 0);
141
142        let mut workers = Vec::with_capacity(args.len());
143
144        let (job_publish, job_receive) = mpsc::channel();
145
146        let job_receive = Arc::new(Mutex::new(job_receive));
147        let (result_publish, result_receive) = mpsc::channel();
148        let (thread_publish_job, thread_receive_job) = mpsc::channel();
149
150        let mut id = 0;
151        while let Some(arg) = args.pop() {
152            workers.push(Worker::new(id, Arc::clone(&job_receive), result_publish.clone(), thread_publish_job.clone(), func, arg));
153            id += 1;
154        }
155
156        let job_publish = Arc::new(Mutex::new(Some(job_publish)));
157        let job_publish_clone = Arc::clone(&job_publish);
158
159        let thread = thread::spawn(move || {
160            ThreadPool::<Job, Result>::pool_entry(job_publish_clone, thread_receive_job);
161        });
162
163        ThreadPool {
164            workers,
165            job_publish,
166            result_receive,
167            thread: Some(thread),
168        }
169    }
170    
171    /// Publish a new job to the thread pool. The job will be distributed to a worker thread.
172    /// 
173    /// # Arguments
174    /// * `job` - The job that should be processed by a worker thread.
175    pub fn publish(&self, job: Job) {
176        let job_publish = self.job_publish.lock();
177        match job_publish {
178            Err(e) => {
179                error!("ThreadPool is shutting down. Cannot publish job. {}", e);
180            }
181            Ok(job_publish) => {
182                match job_publish.as_ref() {
183                    None => {
184                        error!("ThreadPool is shutting down. Cannot publish job.");
185                    }
186                    Some(job_publish) => {
187                        match job_publish.send(job) {
188                            Err(e) => {
189                                error!("Failed to publish job on thread pool. {}", e);
190                            }
191                            Ok(_) => {}
192                        }
193                    }
194                }
195            }
196        }
197    }
198
199    /// Internal function that is run in a separate thread. It feeds back jobs from the worker threads to the input of the thread pool.
200    /// 
201    /// # Arguments
202    /// * `job_publish` - A sender to publish new jobs to the thread pool.
203    /// * `job_receive` - A receiver to receive jobs from the worker threads.
204    fn pool_entry(job_publish: Arc<Mutex<Option<Sender<Job>>>>, job_receive: Receiver<Job>) {
205        loop {
206            let job = job_receive.recv();
207
208            match job {
209                Err(_) => {
210                    trace!("Pool worker shutting down");
211                    break;
212                }
213                Ok(job) => {
214                    match job_publish.lock() {
215                        Err(e) => {
216                            error!("Pool worker shutting down: {}", e);
217                            break;
218                        }
219                        Ok(job_publish) => {
220                            if let Some(job_publish) = job_publish.as_ref() {
221                                job_publish.send(job).expect("Pool worker failed to send job. This should never fail.");
222                            }
223                        }
224                    }
225                }
226            }
227        }
228    }
229    
230    /// Receive a result from the worker threads. This function will block until a result is available.
231    /// 
232    /// # Returns
233    /// * `Result` - The result of a job processed by a worker thread.
234    /// 
235    /// # Errors
236    /// * If all worker threads panicked, therefore the pipe is closed
237    pub fn receive(&self) -> std::result::Result<Result, mpsc::RecvError> {
238        self.result_receive.recv()
239    }
240
241    /// Receive a result from the worker threads. This function will block until a result is available or a timeout occurs.
242    /// 
243    /// # Arguments
244    /// * `timeout` - The maximum time to wait for a result.
245    /// 
246    /// # Returns
247    /// * `Result` - The result of a job processed by a worker thread.
248    /// 
249    /// # Errors
250    /// * If all worker threads panicked, therefore the pipe is closed
251    /// * If the timeout occurs before a result is available
252    pub fn receive_timeout(&self, timeout: Duration) -> std::result::Result<Result, RecvTimeoutError> {
253        self.result_receive.recv_timeout(timeout)
254    }
255}
256
257impl<Job: Send, Result: Send> Drop for ThreadPool<Job, Result> {
258    fn drop(&mut self) {
259        drop(self.job_publish.lock().expect("This should not break").take());
260
261        for worker in &mut self.workers {
262            debug!("Shutting down worker {}", worker.id);
263
264            if let Some(thread) = worker.thread.take() {
265                match thread.join() {
266                    Ok(_) => {
267                        trace!("Worker {} shut down", worker.id);
268                    }
269                    Err(_) => {
270                        warn!("Worker {} panicked", worker.id);
271                    }
272                }
273            }
274        }
275
276        if let Some(thread) = self.thread.take() {
277            match thread.join() {
278                Ok(_) => {
279                    trace!("ThreadPool shut down");
280                }
281                Err(_) => {
282                    warn!("ThreadPool worker panicked");
283                }
284            }
285        }
286    }
287}