use crate::common::Job;
#[cfg(feature = "log")]
use log::debug;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::Receiver;
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;
use std::time::Duration;
/// A worker that runs jobs on a dedicated thread.
///
/// Instances are created with `Worker::new`, which spawns the thread and
/// stores its join handle here.
pub struct Worker {
/// Identifier passed to `Worker::new`; also used in the worker's log messages.
pub(super) id: u8,
/// Join handle of the worker thread. `Worker::new` always stores `Some`;
/// presumably the `Option` lets an owner take the handle out to join it
/// (verify against the code that shuts workers down).
pub(super) thread: Option<JoinHandle<()>>,
}
impl Worker {
    /// Longest time a single wait on the job queue may block before the kill
    /// signal is looked at again.
    const POLL_INTERVAL: Duration = Duration::from_millis(100);

    /// Spawns a worker thread that pulls jobs from `receiver` and runs them.
    ///
    /// The thread loops until one of the following happens:
    /// * `kill_signal` is observed as `true` (checked before every wait on the
    ///   queue and again after every finished job),
    /// * every sender of the channel has been dropped and the queue is empty.
    ///
    /// # Parameters
    /// * `id` - identifier stored in the worker and used in log messages.
    /// * `receiver` - shared receiving end of the job channel.
    /// * `kill_signal` - set to `true` to make the worker stop; a job that is
    ///   already running is allowed to finish first.
    /// * `job_count` - decremented by one every time a job taken by this
    ///   worker ends, whether it returned normally or panicked.
    ///
    /// # Panics
    /// The *worker thread* (not the caller) panics if the receiver mutex is
    /// poisoned. A panicking job unwinds and terminates the worker thread;
    /// the panic is reported through the thread's join handle.
    pub(crate) fn new(
        id: u8,
        receiver: Arc<Mutex<Receiver<Job>>>,
        kill_signal: Arc<AtomicBool>,
        job_count: Arc<AtomicUsize>,
    ) -> Self {
        use std::sync::mpsc::RecvTimeoutError;

        /// Decrements the job counter when dropped, so the count is corrected
        /// even while a panicking job unwinds the worker thread.
        struct JobCountGuard<'a>(&'a AtomicUsize);

        impl Drop for JobCountGuard<'_> {
            fn drop(&mut self) {
                // Release: a thread that reads the lowered count with an
                // Acquire (or stronger) load also sees the job's side effects.
                self.0.fetch_sub(1, Ordering::Release);
            }
        }

        let thread = std::thread::spawn(move || loop {
            let job_msg = {
                let queue = receiver
                    .lock()
                    .expect("job queue mutex poisoned: a thread panicked while holding it");

                // Check the kill signal *after* acquiring the lock. Workers
                // that were queued on the mutex would otherwise each spend a
                // full poll interval waiting for a job once it is their turn,
                // making shutdown time grow with the number of workers.
                if kill_signal.load(Ordering::Relaxed) {
                    #[cfg(feature = "log")]
                    {
                        debug!("Worker {id} received kill signal; shutting down;");
                    }
                    break;
                }

                // The guard is released at the end of this block, so the lock
                // is never held while a job is executing.
                queue.recv_timeout(Self::POLL_INTERVAL)
            };

            match job_msg {
                Ok(job) => {
                    #[cfg(feature = "log")]
                    {
                        debug!("Worker {id} got a job; executing.");
                    }

                    {
                        let _decrement_when_done = JobCountGuard(&job_count);
                        job();
                    }

                    if kill_signal.load(Ordering::Relaxed) {
                        #[cfg(feature = "log")]
                        {
                            debug!(
                                "Worker {id} received kill signal after job; shutting down;"
                            );
                        }
                        break;
                    }
                }
                // Nothing arrived within the poll interval: go around again so
                // the kill signal gets re-checked.
                Err(RecvTimeoutError::Timeout) => {}
                Err(RecvTimeoutError::Disconnected) => {
                    #[cfg(feature = "log")]
                    {
                        debug!("Worker {id} disconnected; shutting down;");
                    }
                    break;
                }
            }
        });

        Self {
            id,
            thread: Some(thread),
        }
    }

    /// Returns the identifier this worker was created with.
    #[inline]
    pub fn get_id(&self) -> u8 {
        self.id
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc::{self, Sender};
    use std::time::Duration;

    /// Longest a test waits for a job to report in before it is considered hung.
    const SIGNAL_TIMEOUT: Duration = Duration::from_secs(5);

    /// Everything a test needs to drive a single worker.
    struct Fixture {
        sender: Sender<Job>,
        kill_signal: Arc<AtomicBool>,
        job_count: Arc<AtomicUsize>,
        worker: Worker,
    }

    /// Starts worker `id` on a fresh channel.
    ///
    /// `pending_jobs` seeds the job counter and must equal the number of jobs
    /// the test is going to send: the worker decrements the counter once per
    /// finished job, so a lower seed would make the counter wrap below zero.
    fn start_worker(id: u8, pending_jobs: usize) -> Fixture {
        let (sender, receiver) = mpsc::channel::<Job>();
        let kill_signal = Arc::new(AtomicBool::new(false));
        let job_count = Arc::new(AtomicUsize::new(pending_jobs));
        let worker = Worker::new(
            id,
            Arc::new(Mutex::new(receiver)),
            Arc::clone(&kill_signal),
            Arc::clone(&job_count),
        );
        Fixture {
            sender,
            kill_signal,
            job_count,
            worker,
        }
    }

    #[test]
    fn test_worker_creation() {
        let fx = start_worker(1, 0);
        assert_eq!(fx.worker.id, 1);
        assert_eq!(fx.worker.get_id(), 1);
        assert!(fx.worker.thread.is_some());
        drop(fx.sender);
        fx.worker.thread.unwrap().join().unwrap();
    }

    #[test]
    fn test_worker_executes_job() {
        let fx = start_worker(2, 1);
        let executed = Arc::new(AtomicBool::new(false));
        let executed_in_job = Arc::clone(&executed);
        fx.sender
            .send(Box::new(move || {
                executed_in_job.store(true, Ordering::SeqCst);
            }))
            .unwrap();
        // A queued job is still delivered after the sender is gone, so joining
        // the worker is a deterministic way to wait for the job; no sleep needed.
        drop(fx.sender);
        fx.worker.thread.unwrap().join().unwrap();
        assert!(executed.load(Ordering::SeqCst));
        assert_eq!(fx.job_count.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_worker_shutdown_on_channel_close() {
        let fx = start_worker(3, 0);
        drop(fx.sender);
        let result = fx.worker.thread.unwrap().join();
        assert!(result.is_ok());
    }

    #[test]
    fn test_worker_shutdown_on_kill_signal() {
        let fx = start_worker(4, 0);
        fx.kill_signal.store(true, Ordering::Relaxed);
        let result = fx.worker.thread.unwrap().join();
        assert!(result.is_ok());
        // The sender outlives the join, so the worker can only have stopped
        // because of the kill signal, not because the channel closed.
        drop(fx.sender);
    }

    #[test]
    fn test_worker_stops_after_current_job() {
        let fx = start_worker(5, 1);
        let (started_tx, started_rx) = mpsc::channel::<()>();
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let job_completed = Arc::new(AtomicBool::new(false));
        let job_completed_in_job = Arc::clone(&job_completed);
        fx.sender
            .send(Box::new(move || {
                started_tx.send(()).unwrap();
                // Stay "in progress" until the test has raised the kill signal.
                release_rx.recv_timeout(SIGNAL_TIMEOUT).unwrap();
                job_completed_in_job.store(true, Ordering::SeqCst);
            }))
            .unwrap();
        started_rx
            .recv_timeout(SIGNAL_TIMEOUT)
            .expect("job was never started by the worker");
        fx.kill_signal.store(true, Ordering::Relaxed);
        release_tx.send(()).unwrap();
        fx.worker.thread.unwrap().join().unwrap();
        assert!(job_completed.load(Ordering::SeqCst));
        assert_eq!(fx.job_count.load(Ordering::SeqCst), 0);
        drop(fx.sender);
    }
}