use std::sync::{mpsc, Arc, Mutex};
use std::thread::{self, JoinHandle};
use crate::monitor::{Command, MonitorSender};
/// Sending half of the bounded channel a worker uses to hand back a [`ResultPacket`].
pub type ResultsSender<T> = mpsc::SyncSender<ResultPacket<T>>;
/// Receiving half matching [`ResultsSender`].
pub type ResultsReceiver<T> = mpsc::Receiver<ResultPacket<T>>;
/// Sending half of the bounded channel used to queue [`Message`]s for the workers.
pub type JobsSender<T> = mpsc::SyncSender<Message<T>>;
/// Receiving half matching [`JobsSender`]; workers share it behind an `Arc<Mutex<_>>`.
pub type JobsReceiver<T> = mpsc::Receiver<Message<T>>;
/// A group of worker threads that all pull [`Message`]s from one shared job channel.
///
/// Dropping the pool sends one [`Message::Terminate`] per worker and then joins
/// every worker thread.
pub struct ThreadPool<T> {
/// Sender used to queue messages (jobs or termination requests) for the workers.
pub jobs_tx: JobsSender<T>,
/// The pool's workers; their thread handles are taken and joined on drop.
workers: Vec<Worker>,
}
/// A message delivered to a worker over the job channel.
pub enum Message<T> {
/// A job to run, bundled with the channel its result is sent back on.
Job(JobPacket<T>),
/// Makes the worker that receives it leave its loop, ending its thread.
Terminate,
}
/// A unit of work together with the channel on which to return its result.
pub struct JobPacket<T> {
/// Function pointer the worker calls to produce the result.
pub job: fn() -> T,
/// Channel on which the worker sends the [`ResultPacket`] wrapping the job's return value.
pub return_tx: ResultsSender<T>,
}
/// Wrapper around the value produced by a finished job.
pub struct ResultPacket<T> {
/// The value returned by the job function.
pub result: T,
}
impl<T: Send + 'static> ThreadPool<T> {
    /// Builds a pool of `size` workers that share a single job queue.
    ///
    /// The job queue is a bounded channel with room for 100 pending messages.
    /// Every worker receives its own clone of `monitor_tx`.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    pub fn new(size: usize, monitor_tx: MonitorSender) -> Self {
        assert!(size > 0);

        let (jobs_tx, receiver) = mpsc::sync_channel(100);
        // All workers pull from the same receiver, so it is shared behind a lock.
        let shared_rx = Arc::new(Mutex::new(receiver));

        let mut workers = Vec::with_capacity(size);
        for _ in 0..size {
            workers.push(Worker::new(Arc::clone(&shared_rx), monitor_tx.clone()));
        }

        Self { jobs_tx, workers }
    }
}
/// Owner of one spawned worker thread.
struct Worker {
/// Join handle of the worker thread; stored as an `Option` so it can be
/// taken out and joined.
handle: Option<JoinHandle<()>>,
}
impl Worker {
    /// Spawns a thread that services messages from `jobs_rx` until it is told
    /// to stop.
    ///
    /// `monitor_tx` is moved into the thread and is sent [`Command::Refresh`]
    /// after every completed job.
    pub fn new<T: Send + 'static>(
        jobs_rx: Arc<Mutex<JobsReceiver<T>>>, monitor_tx: MonitorSender,
    ) -> Self {
        let handle = thread::spawn(move || Self::listen(jobs_rx, monitor_tx));
        Self { handle: Some(handle) }
    }

    /// Worker loop: receive a message, run the job, send back its result and
    /// notify the monitor.
    ///
    /// The loop ends when [`Message::Terminate`] is received or when the job
    /// channel is closed (every sender has been dropped). Failing to deliver a
    /// result or a monitor notification is not fatal: the other side has simply
    /// gone away, and the worker keeps serving the remaining jobs.
    ///
    /// # Panics
    ///
    /// Panics if the mutex guarding `jobs_rx` is poisoned, or if a job itself
    /// panics.
    fn listen<T: Send + 'static>(
        jobs_rx: Arc<Mutex<JobsReceiver<T>>>, monitor_tx: MonitorSender,
    ) {
        loop {
            // The guard is a temporary that is dropped at the end of this
            // statement, so the lock is released before the job runs and other
            // workers can pick up messages in the meantime.
            let received = jobs_rx
                .lock()
                .expect("jobs receiver mutex poisoned")
                .recv();

            match received {
                Ok(Message::Job(JobPacket { job, return_tx })) => {
                    // The requester may have dropped its receiver; that only
                    // means nobody wants this result any more.
                    let _ = return_tx.send(ResultPacket { result: job() });
                    // The refresh is advisory; a vanished monitor must not
                    // take the worker down with it.
                    let _ = monitor_tx.send(Command::Refresh);
                },
                // An explicit stop request, or no sender left to produce work.
                Ok(Message::Terminate) | Err(_) => break,
            }
        }
    }
}
impl<T> Drop for ThreadPool<T> {
    /// Shuts the pool down: asks every worker to terminate, then joins all of
    /// the worker threads.
    ///
    /// # Panics
    ///
    /// If a worker thread panicked, that panic is re-raised here after all
    /// workers have been joined, unless the current thread is already
    /// unwinding (a second panic inside `drop` would abort the process).
    fn drop(&mut self) {
        for _ in &self.workers {
            // A send error means the receiver is gone, i.e. every worker has
            // already exited; there is nobody left to notify.
            if self.jobs_tx.send(Message::Terminate).is_err() {
                break;
            }
        }

        // Join every worker before reporting a failure so that no thread is
        // left detached because an earlier one panicked.
        let mut first_panic = None;
        for worker in &mut self.workers {
            if let Some(thread) = worker.handle.take() {
                if let Err(payload) = thread.join() {
                    if first_panic.is_none() {
                        first_panic = Some(payload);
                    }
                }
            }
        }

        if let Some(payload) = first_panic {
            if !thread::panicking() {
                std::panic::resume_unwind(payload);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone worker leaves its loop and becomes joinable after `Terminate`.
    #[test]
    fn workers_can_terminate() {
        let (sender, receiver) = mpsc::sync_channel::<Message<String>>(10);
        let shared_rx = Arc::new(Mutex::new(receiver));
        // The monitor receiver is kept alive for the whole test.
        let (monitor_tx, _monitor_rx) = mpsc::sync_channel(10);
        let worker = Worker::new(shared_rx, monitor_tx);

        sender.send(Message::Terminate).unwrap();

        let handle = worker.handle.unwrap();
        handle.join().unwrap();
    }

    /// A worker runs the submitted job and sends its value back on the
    /// job's return channel.
    #[test]
    fn workers_return_correct_values() {
        let (sender, receiver) = mpsc::sync_channel(10);
        let (return_tx, results) = mpsc::sync_channel(10);
        let shared_rx = Arc::new(Mutex::new(receiver));
        // The monitor receiver is kept alive for the whole test.
        let (monitor_tx, _monitor_rx) = mpsc::sync_channel(10);
        let worker = Worker::new(shared_rx, monitor_tx);

        let packet = JobPacket {
            job: || String::from("the test worked :)"),
            return_tx,
        };
        sender.send(Message::Job(packet)).unwrap();

        let result = results.recv().unwrap().result;
        assert_eq!(result, "the test worked :)");

        sender.send(Message::Terminate).unwrap();
        let handle = worker.handle.unwrap();
        handle.join().unwrap();
    }
}