jlizard_simple_threadpool/
worker.rs

1//! Worker model for concurrent jobs handling
2use crate::common::Job;
3#[cfg(feature = "log")]
4use log::debug;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::mpsc::Receiver;
7use std::sync::{Arc, Mutex};
8use std::thread::JoinHandle;
9use std::time::Duration;
10
11pub struct Worker {
12    pub(super) id: u8,
13    pub(super) thread: Option<JoinHandle<()>>,
14}
15
16impl Worker {
17    /// Creates a new worker that spawns a thread to process jobs from the shared receiver.
18    ///
19    /// The worker continuously receives jobs from the channel until the sender is dropped
20    /// or the kill signal is set, at which point it exits gracefully.
21    pub(crate) fn new(
22        id: u8,
23        receiver: Arc<Mutex<Receiver<Job>>>,
24        kill_signal: Arc<AtomicBool>,
25    ) -> Self {
26        let thread = std::thread::spawn(move || {
27            loop {
28                // Check kill signal before trying to receive
29                if kill_signal.load(Ordering::Relaxed) {
30                    #[cfg(feature = "log")]
31                    {
32                        debug!("Worker {id} received kill signal; shutting down;");
33                    }
34                    break;
35                }
36
37                // Use recv_timeout to periodically check kill signal
38                let job_msg = receiver
39                    .lock()
40                    .unwrap()
41                    .recv_timeout(Duration::from_millis(100));
42
43                match job_msg {
44                    Ok(job) => {
45                        #[cfg(feature = "log")]
46                        {
47                            debug!("Worker {id} got a job; executing.");
48                        }
49                        job();
50
51                        // Check kill signal after job execution
52                        if kill_signal.load(Ordering::Relaxed) {
53                            #[cfg(feature = "log")]
54                            {
55                                debug!(
56                                    "Worker {id} received kill signal after job; shutting down;"
57                                );
58                            }
59                            break;
60                        }
61                    }
62                    Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
63                        // Timeout - loop back to check kill signal
64                        continue;
65                    }
66                    Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
67                        #[cfg(feature = "log")]
68                        {
69                            debug!("Worker {id} disconnected; shutting down;");
70                        }
71                        break;
72                    }
73                }
74            }
75        });
76
77        Self {
78            id,
79            thread: Some(thread),
80        }
81    }
82
83    /// get id of the worker
84    #[inline]
85    pub fn get_id(&self) -> u8 {
86        self.id
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    /// Builds the shared plumbing a `Worker` needs: the job sender, the mutex-wrapped receiver,
    /// and a kill signal that starts out `false`.
    fn setup() -> (mpsc::Sender<Job>, Arc<Mutex<Receiver<Job>>>, Arc<AtomicBool>) {
        let (sender, receiver) = mpsc::channel::<Job>();
        (
            sender,
            Arc::new(Mutex::new(receiver)),
            Arc::new(AtomicBool::new(false)),
        )
    }

    #[test]
    fn test_worker_creation() {
        let (sender, receiver, kill_signal) = setup();

        let worker = Worker::new(1, receiver, kill_signal);

        assert_eq!(worker.get_id(), 1);
        assert!(worker.thread.is_some());

        // Clean up: closing the channel lets the worker exit.
        drop(sender);
        worker.thread.unwrap().join().unwrap();
    }

    #[test]
    fn test_worker_executes_job() {
        let (sender, receiver, kill_signal) = setup();
        let executed = Arc::new(AtomicBool::new(false));
        let executed_in_job = Arc::clone(&executed);

        let worker = Worker::new(2, receiver, kill_signal);
        sender
            .send(Box::new(move || executed_in_job.store(true, Ordering::SeqCst)))
            .unwrap();

        // Messages sent before the disconnect are still delivered. So dropping the only
        // sender lets the worker drain the job and then exit on `Disconnected`. Joining is
        // therefore a deterministic "the job has run" barrier; no sleep is needed.
        drop(sender);
        worker.thread.unwrap().join().unwrap();

        assert!(executed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_worker_shutdown_on_channel_close() {
        let (sender, receiver, kill_signal) = setup();

        let worker = Worker::new(3, receiver, kill_signal);

        // Close the channel by dropping the sender; the worker must exit gracefully.
        drop(sender);
        let result = worker.thread.unwrap().join();
        assert!(result.is_ok());
    }

    #[test]
    fn test_worker_shutdown_on_kill_signal() {
        let (sender, receiver, kill_signal) = setup();

        let worker = Worker::new(4, receiver, Arc::clone(&kill_signal));
        kill_signal.store(true, Ordering::Relaxed);

        // The sender is still alive, so only the kill signal can make this join return.
        let result = worker.thread.unwrap().join();
        assert!(result.is_ok());

        drop(sender);
    }

    #[test]
    fn test_worker_stops_after_current_job() {
        let (sender, receiver, kill_signal) = setup();
        let (started_tx, started_rx) = mpsc::channel::<()>();
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let job_completed = Arc::new(AtomicBool::new(false));
        let job_completed_in_job = Arc::clone(&job_completed);
        let second_job_ran = Arc::new(AtomicBool::new(false));
        let second_job_ran_in_job = Arc::clone(&second_job_ran);

        let worker = Worker::new(5, receiver, Arc::clone(&kill_signal));

        // A job that reports it has started, then blocks until the test releases it.
        sender
            .send(Box::new(move || {
                started_tx.send(()).unwrap();
                // If the test panics before releasing, `release_tx` is dropped and `recv`
                // returns `Err`, so the job cannot hang the worker.
                let _ = release_rx.recv();
                job_completed_in_job.store(true, Ordering::SeqCst);
            }))
            .unwrap();

        // Wait (deterministically) until the job is running.
        started_rx.recv().unwrap();

        // Queue a second job while the worker is busy; it must never run.
        sender
            .send(Box::new(move || {
                second_job_ran_in_job.store(true, Ordering::SeqCst)
            }))
            .unwrap();

        // Signal kill while the first job is still running, then let it finish.
        kill_signal.store(true, Ordering::SeqCst);
        release_tx.send(()).unwrap();

        worker.thread.unwrap().join().unwrap();

        // The running job completed, and the worker stopped before taking the queued one.
        assert!(job_completed.load(Ordering::SeqCst));
        assert!(!second_job_ran.load(Ordering::SeqCst));

        drop(sender);
    }
}