jlizard_simple_threadpool/
worker.rs

//! Worker threads that receive jobs from a shared channel and execute them concurrently.
2use crate::common::Job;
3#[cfg(feature = "log")]
4use log::debug;
5use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
6use std::sync::mpsc::Receiver;
7use std::sync::{Arc, Mutex};
8use std::thread::JoinHandle;
9use std::time::Duration;
10
11pub struct Worker {
12    pub(super) id: u8,
13    pub(super) thread: Option<JoinHandle<()>>,
14}
15
16impl Worker {
17    /// Creates a new worker that spawns a thread to process jobs from the shared receiver.
18    ///
19    /// The worker continuously receives jobs from the channel until the sender is dropped
20    /// or the kill signal is set, at which point it exits gracefully.
21    pub(crate) fn new(
22        id: u8,
23        receiver: Arc<Mutex<Receiver<Job>>>,
24        kill_signal: Arc<AtomicBool>,
25        job_count: Arc<AtomicUsize>,
26    ) -> Self {
27        let thread = std::thread::spawn(move || {
28            loop {
29                // Check kill signal before trying to receive
30                if kill_signal.load(Ordering::Relaxed) {
31                    #[cfg(feature = "log")]
32                    {
33                        debug!("Worker {id} received kill signal; shutting down;");
34                    }
35                    break;
36                }
37
38                // Use recv_timeout to periodically check kill signal
39                let job_msg = receiver
40                    .lock()
41                    .unwrap()
42                    .recv_timeout(Duration::from_millis(100));
43
44                match job_msg {
45                    Ok(job) => {
46                        #[cfg(feature = "log")]
47                        {
48                            debug!("Worker {id} got a job; executing.");
49                        }
50                        job();
51                        job_count.fetch_sub(1, Ordering::Relaxed);
52
53                        // Check kill signal after job execution
54                        if kill_signal.load(Ordering::Relaxed) {
55                            #[cfg(feature = "log")]
56                            {
57                                debug!(
58                                    "Worker {id} received kill signal after job; shutting down;"
59                                );
60                            }
61                            break;
62                        }
63                    }
64                    Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
65                        // Timeout - loop back to check kill signal
66                        continue;
67                    }
68                    Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
69                        #[cfg(feature = "log")]
70                        {
71                            debug!("Worker {id} disconnected; shutting down;");
72                        }
73                        break;
74                    }
75                }
76            }
77        });
78
79        Self {
80            id,
81            thread: Some(thread),
82        }
83    }
84
85    /// get id of the worker
86    #[inline]
87    pub fn get_id(&self) -> u8 {
88        self.id
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a worker-side event, so a
    /// regression shows up as a failure rather than a hang.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(5);

    /// A spawned worker plus the handles a test uses to drive and observe it.
    struct Harness {
        sender: mpsc::Sender<Job>,
        kill_signal: Arc<AtomicBool>,
        job_count: Arc<AtomicUsize>,
        worker: Worker,
    }

    /// Spawns worker `id` on a fresh channel with `job_count` preset to `pending`.
    ///
    /// `Worker` only ever decrements the counter, so presetting it to the number
    /// of jobs a test expects to run keeps it from wrapping below zero.
    fn spawn_worker(id: u8, pending: usize) -> Harness {
        let (sender, receiver) = mpsc::channel::<Job>();
        let kill_signal = Arc::new(AtomicBool::new(false));
        let job_count = Arc::new(AtomicUsize::new(pending));
        let worker = Worker::new(
            id,
            Arc::new(Mutex::new(receiver)),
            Arc::clone(&kill_signal),
            Arc::clone(&job_count),
        );
        Harness {
            sender,
            kill_signal,
            job_count,
            worker,
        }
    }

    #[test]
    fn test_worker_creation() {
        let harness = spawn_worker(1, 0);

        assert_eq!(harness.worker.id, 1);
        assert_eq!(harness.worker.get_id(), 1);
        assert!(harness.worker.thread.is_some());

        // Clean up: closing the channel lets the worker exit.
        drop(harness.sender);
        harness.worker.thread.unwrap().join().unwrap();
    }

    #[test]
    fn test_worker_executes_job() {
        let harness = spawn_worker(2, 1);
        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = Arc::clone(&executed);

        harness
            .sender
            .send(Box::new(move || executed_clone.store(true, Ordering::SeqCst)))
            .unwrap();

        // Queued jobs are delivered before the disconnect, so once the worker has
        // been joined the job has definitely run; no sleeping is needed.
        drop(harness.sender);
        harness.worker.thread.unwrap().join().unwrap();

        assert!(executed.load(Ordering::SeqCst));
        assert_eq!(harness.job_count.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_worker_shutdown_on_channel_close() {
        let harness = spawn_worker(3, 0);

        // Close the channel by dropping the only sender.
        drop(harness.sender);

        // The worker thread should exit gracefully.
        assert!(harness.worker.thread.unwrap().join().is_ok());
    }

    #[test]
    fn test_worker_shutdown_on_kill_signal() {
        let harness = spawn_worker(4, 0);

        harness.kill_signal.store(true, Ordering::Relaxed);

        // The sender is still alive, so the only way out of the loop is the
        // kill signal, which the worker notices at its next poll.
        assert!(harness.worker.thread.unwrap().join().is_ok());

        drop(harness.sender);
    }

    #[test]
    fn test_worker_stops_after_current_job() {
        let harness = spawn_worker(5, 2);

        let (started_tx, started_rx) = mpsc::channel::<()>();
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let job_completed = Arc::new(AtomicBool::new(false));
        let second_job_ran = Arc::new(AtomicBool::new(false));

        let job_completed_clone = Arc::clone(&job_completed);
        harness
            .sender
            .send(Box::new(move || {
                started_tx.send(()).unwrap();
                // Block until the test has raised the kill signal, so the signal
                // is guaranteed to arrive while this job is still running.
                release_rx.recv().unwrap();
                job_completed_clone.store(true, Ordering::SeqCst);
            }))
            .unwrap();

        // Queued behind the running job; it must not run once the kill is seen.
        let second_job_ran_clone = Arc::clone(&second_job_ran);
        harness
            .sender
            .send(Box::new(move || {
                second_job_ran_clone.store(true, Ordering::SeqCst)
            }))
            .unwrap();

        started_rx
            .recv_timeout(EVENT_TIMEOUT)
            .expect("worker never started the first job");
        harness.kill_signal.store(true, Ordering::Relaxed);
        release_tx.send(()).unwrap();

        harness.worker.thread.unwrap().join().unwrap();

        // The in-flight job finished; the queued one was left untouched.
        assert!(job_completed.load(Ordering::SeqCst));
        assert!(!second_job_ran.load(Ordering::SeqCst));
        assert_eq!(harness.job_count.load(Ordering::SeqCst), 1);

        drop(harness.sender);
    }
}