jlizard-simple-threadpool 0.3.0

A simple, lightweight threadpool implementation
Documentation
//! Worker model for handling concurrent jobs.
use crate::common::Job;
#[cfg(feature = "log")]
use log::debug;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::Receiver;
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;
use std::time::Duration;

/// A worker that owns one spawned thread which receives and executes jobs.
///
/// Created with `Worker::new`, which starts the thread immediately.
#[derive(Debug)]
pub struct Worker {
    /// Identifier given to `Worker::new`; also printed in the optional `log` debug messages.
    pub(super) id: u8,
    /// Handle of the spawned thread; `Some` right after construction. Kept in an `Option`
    /// presumably so the owner can take the handle out to join it — TODO confirm in the pool.
    pub(super) thread: Option<JoinHandle<()>>,
}

impl Worker {
    /// Upper bound on how long a worker blocks waiting for a job before it re-checks the
    /// kill signal; this bounds how long an idle worker takes to notice a shutdown request.
    const RECV_TIMEOUT: Duration = Duration::from_millis(100);

    /// Creates a new worker that spawns a thread to process jobs from the shared receiver.
    ///
    /// The thread takes one job at a time from `receiver` and runs it. After each job
    /// finishes, it decrements `job_count` by one, including when the job panics.
    ///
    /// The thread exits when either:
    ///
    /// - `kill_signal` is observed as `true`. It is checked before waiting for a job, again
    ///   right after acquiring the receiver lock, and after every completed job; or
    /// - every sender has been dropped and no queued jobs remain.
    ///
    /// A job that is already running when the kill signal is set still runs to completion.
    /// Jobs still queued once this worker observes the kill signal are not run by it.
    ///
    /// # Panics
    ///
    /// Panics if the operating system fails to create the thread, like
    /// [`std::thread::spawn`].
    pub(crate) fn new(
        id: u8,
        receiver: Arc<Mutex<Receiver<Job>>>,
        kill_signal: Arc<AtomicBool>,
        job_count: Arc<AtomicUsize>,
    ) -> Self {
        let thread = std::thread::spawn(move || {
            Self::run(id, &receiver, &kill_signal, &job_count);
        });

        Self {
            id,
            thread: Some(thread),
        }
    }

    /// Body of the worker thread: receives and executes jobs until told to stop.
    // `id` is only read by the debug messages of the optional `log` feature.
    #[cfg_attr(not(feature = "log"), allow(unused_variables))]
    fn run(
        id: u8,
        receiver: &Mutex<Receiver<Job>>,
        kill_signal: &AtomicBool,
        job_count: &AtomicUsize,
    ) {
        loop {
            // Check kill signal before trying to receive.
            if kill_signal.load(Ordering::Relaxed) {
                #[cfg(feature = "log")]
                {
                    debug!("Worker {id} received kill signal; shutting down;");
                }
                break;
            }

            // Only `recv_timeout` runs while this lock is held. Still recover from poisoning,
            // so that a failure there cannot cascade into every worker panicking on `lock()`.
            let rx = receiver
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);

            // While blocked on the lock, other workers may each have held it for a full
            // `RECV_TIMEOUT`. Re-check now, so that shutdown latency does not grow with the
            // number of workers.
            if kill_signal.load(Ordering::Relaxed) {
                #[cfg(feature = "log")]
                {
                    debug!("Worker {id} received kill signal; shutting down;");
                }
                break;
            }

            // Time out periodically so the kill signal is re-checked even when idle.
            let job_msg = rx.recv_timeout(Self::RECV_TIMEOUT);
            // Release the lock before running the job so other workers can receive meanwhile.
            drop(rx);

            match job_msg {
                Ok(job) => {
                    #[cfg(feature = "log")]
                    {
                        debug!("Worker {id} got a job; executing.");
                    }
                    Self::execute(job, job_count);

                    // Check kill signal after job execution.
                    if kill_signal.load(Ordering::Relaxed) {
                        #[cfg(feature = "log")]
                        {
                            debug!(
                                "Worker {id} received kill signal after job; shutting down;"
                            );
                        }
                        break;
                    }
                }
                // No job within the timeout: loop back and re-check the kill signal.
                Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
                Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                    #[cfg(feature = "log")]
                    {
                        debug!("Worker {id} disconnected; shutting down;");
                    }
                    break;
                }
            }
        }
    }

    /// Runs a single `job`, then decrements `job_count` by one.
    ///
    /// A drop guard performs the decrement, so it also happens when the job panics.
    /// Without it, a panicking job would never be subtracted from `job_count`.
    /// The panic itself is not caught: it keeps unwinding and ends the worker thread.
    fn execute(job: Job, job_count: &AtomicUsize) {
        /// Decrements the wrapped counter when dropped, including during unwinding.
        struct DecrementOnDrop<'a>(&'a AtomicUsize);

        impl Drop for DecrementOnDrop<'_> {
            fn drop(&mut self) {
                self.0.fetch_sub(1, Ordering::Relaxed);
            }
        }

        let _pending = DecrementOnDrop(job_count);
        job();
    }

    /// Returns the numeric identifier this worker was created with.
    #[inline]
    pub fn get_id(&self) -> u8 {
        self.id
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::time::Duration;

    /// Generous upper bound when waiting for a job's side effect. It keeps slow CI machines
    /// from causing spurious failures without letting a lost job hang the test.
    const TEST_TIMEOUT: Duration = Duration::from_secs(5);

    /// Spawns a worker with the given `id` on a fresh channel.
    ///
    /// Returns the worker, the job sender, the worker's kill signal and its job counter.
    fn spawn_worker(id: u8) -> (Worker, mpsc::Sender<Job>, Arc<AtomicBool>, Arc<AtomicUsize>) {
        let (sender, receiver) = mpsc::channel::<Job>();
        let kill_signal = Arc::new(AtomicBool::new(false));
        let job_count = Arc::new(AtomicUsize::new(0));

        let worker = Worker::new(
            id,
            Arc::new(Mutex::new(receiver)),
            Arc::clone(&kill_signal),
            Arc::clone(&job_count),
        );

        (worker, sender, kill_signal, job_count)
    }

    /// Queues `job` and increments `job_count`, mirroring the per-job decrement the worker does.
    fn submit(sender: &mpsc::Sender<Job>, job_count: &AtomicUsize, job: Job) {
        job_count.fetch_add(1, Ordering::Relaxed);
        sender.send(job).unwrap();
    }

    #[test]
    fn test_worker_creation() {
        let (worker, sender, _kill_signal, _job_count) = spawn_worker(1);

        assert_eq!(worker.id, 1);
        assert_eq!(worker.get_id(), 1);
        assert!(worker.thread.is_some());

        // Clean up: closing the channel lets the worker exit.
        drop(sender);
        worker.thread.unwrap().join().unwrap();
    }

    #[test]
    fn test_worker_executes_job() {
        let (worker, sender, _kill_signal, job_count) = spawn_worker(2);
        let (done_tx, done_rx) = mpsc::channel();

        submit(&sender, &job_count, Box::new(move || done_tx.send(()).unwrap()));

        // Wait for the job's own signal instead of sleeping for a fixed time.
        done_rx
            .recv_timeout(TEST_TIMEOUT)
            .expect("worker did not execute the job in time");

        drop(sender);
        worker.thread.unwrap().join().unwrap();
        assert_eq!(job_count.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_worker_drains_queue_before_disconnect() {
        let (worker, sender, _kill_signal, job_count) = spawn_worker(6);
        let executed = Arc::new(AtomicUsize::new(0));

        for _ in 0..3 {
            let executed = Arc::clone(&executed);
            submit(
                &sender,
                &job_count,
                Box::new(move || {
                    executed.fetch_add(1, Ordering::Relaxed);
                }),
            );
        }
        // Closing the channel must not discard jobs that were already queued.
        drop(sender);

        worker.thread.unwrap().join().unwrap();
        assert_eq!(executed.load(Ordering::Relaxed), 3);
        assert_eq!(job_count.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_worker_shutdown_on_channel_close() {
        let (worker, sender, _kill_signal, _job_count) = spawn_worker(3);

        // Close channel by dropping sender.
        drop(sender);

        // Worker thread should exit gracefully.
        let result = worker.thread.unwrap().join();
        assert!(result.is_ok());
    }

    #[test]
    fn test_worker_shutdown_on_kill_signal() {
        let (worker, sender, kill_signal, _job_count) = spawn_worker(4);

        // Let the worker start waiting on the channel, so the timeout path is exercised.
        std::thread::sleep(Duration::from_millis(50));

        kill_signal.store(true, Ordering::Relaxed);

        // The worker exits on its own even though the channel is still open.
        let result = worker.thread.unwrap().join();
        assert!(result.is_ok());

        // Dropping the sender only now shows that the kill signal alone caused the exit.
        drop(sender);
    }

    #[test]
    fn test_worker_stops_after_current_job() {
        let (worker, sender, kill_signal, job_count) = spawn_worker(5);
        let (started_tx, started_rx) = mpsc::channel();
        let job_completed = Arc::new(AtomicBool::new(false));
        let job_completed_clone = Arc::clone(&job_completed);

        // Send a long-running job that announces when it starts.
        submit(
            &sender,
            &job_count,
            Box::new(move || {
                started_tx.send(()).unwrap();
                std::thread::sleep(Duration::from_millis(100));
                job_completed_clone.store(true, Ordering::Relaxed);
            }),
        );

        // Wait until the job is definitely running before signalling kill.
        started_rx
            .recv_timeout(TEST_TIMEOUT)
            .expect("job did not start in time");
        kill_signal.store(true, Ordering::Relaxed);

        worker.thread.unwrap().join().unwrap();

        // The in-flight job ran to completion before the worker stopped.
        assert!(job_completed.load(Ordering::Relaxed));
        assert_eq!(job_count.load(Ordering::Relaxed), 0);

        drop(sender);
    }
}