use std::{
fmt::Debug,
sync::{Arc, Mutex, OnceLock},
};
use std::{sync::mpsc, thread};
/// Holder for the process-wide `ThreadPool`, created lazily by `FixedThreadPool::instance`.
struct FixedThreadPool {
/// The shared pool; `get_thread_pool` hands out clones of this `Arc`.
inner: Arc<ThreadPool>,
}
impl FixedThreadPool {
    /// Returns the process-wide pool, creating it with `num_workers` threads on the first call.
    ///
    /// The pool is initialised exactly once: `num_workers` is ignored on every later call, so all
    /// callers share the pool sized by whichever call ran first. Because the pool lives in a
    /// `static`, it is never dropped and its workers are never joined.
    fn instance(num_workers: usize) -> &'static FixedThreadPool {
        static INSTANCE: OnceLock<FixedThreadPool> = OnceLock::new();
        INSTANCE.get_or_init(|| FixedThreadPool {
            inner: Arc::new(ThreadPool::new(num_workers)),
        })
    }
}
/// Returns a shared handle to the process-wide thread pool.
///
/// The first call creates the pool with `num_workers` threads; every later call returns the
/// same pool and its `num_workers` argument has no effect.
pub fn get_thread_pool(num_workers: usize) -> Arc<ThreadPool> {
    let shared = &FixedThreadPool::instance(num_workers).inner;
    Arc::clone(shared)
}
/// Handle for collecting the value returned by a job submitted via
/// `ThreadPool::submit_with_result`.
pub struct WorkResult<T> {
/// Receiving end of the channel the job sends its return value on.
receiver: mpsc::Receiver<T>,
}
impl<T> WorkResult<T> {
    /// Wraps `rx`, the receiving end of the channel a job will send its result on.
    pub fn new(rx: mpsc::Receiver<T>) -> Self {
        WorkResult { receiver: rx }
    }

    /// Blocks until the job has produced its value and returns it.
    ///
    /// # Panics
    ///
    /// Panics if the sending side has been dropped and no value is pending. This happens, for
    /// example, when the job panicked before returning, or when the value was already taken by an
    /// earlier call. Use [`WorkResult::try_result`] to handle that case without panicking.
    pub fn result(&self) -> T {
        self.try_result().expect(
            "thread pool job exited without sending a result (it panicked, or the result was already taken)",
        )
    }

    /// Blocks until the job has produced its value, reporting a missing value as an error instead
    /// of panicking.
    ///
    /// # Errors
    ///
    /// Returns [`mpsc::RecvError`] if the sending side has been dropped and no value is pending.
    pub fn try_result(&self) -> Result<T, mpsc::RecvError> {
        self.receiver.recv()
    }
}
/// A fixed-size pool of worker threads that execute submitted closures.
///
/// Jobs travel over a single channel whose receiver is shared by all workers; dropping the pool
/// sends one `Terminate` message per worker and joins their threads.
pub struct ThreadPool {
/// Sending half of the job queue shared by all workers.
sender: mpsc::Sender<Message>,
/// Workers created by `ThreadPool::new`; their thread handles are taken when the pool is dropped.
workers: Vec<Worker>,
}
impl ThreadPool {
pub fn new(size: usize) -> Self {
let (sender, receiver) = mpsc::channel();
let receiver = Arc::new(Mutex::new(receiver));
ThreadPool {
sender,
workers: (0..size)
.map(|id| Worker::new(id, Arc::clone(&receiver)))
.collect(),
}
}
pub fn num_workers(&self) -> usize {
self.workers.len()
}
pub fn is_alive(&self) -> bool {
self.workers.iter().any(|worker| worker.is_alive())
}
pub fn submit<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
let job = Box::new(f);
self.sender.send(Message::Work(job)).unwrap();
}
pub fn submit_with_result<F, T>(&self, f: F) -> WorkResult<T>
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
let (tx, rx) = mpsc::sync_channel(1);
let job = Box::new(move || tx.send(f()).unwrap());
self.sender.send(Message::Work(job)).unwrap();
WorkResult { receiver: rx }
}
}
impl Drop for ThreadPool {
fn drop(&mut self) {
for _ in self.workers.iter() {
self.sender.send(Message::Terminate).unwrap();
}
for worker in self.workers.iter_mut() {
if let Some(thread) = worker.thread.take() {
thread.join().unwrap();
}
}
assert!(!self.is_alive());
}
}
/// A boxed, type-erased closure that a worker runs exactly once.
type Job = Box<dyn FnOnce() + Send + 'static>;
/// Messages sent from a `ThreadPool` to its workers over the shared job queue.
enum Message {
/// Run the contained job.
Work(Job),
/// Stop the worker that receives this message; `ThreadPool`'s `Drop` sends one per worker.
Terminate,
}
/// A single pool thread that repeatedly takes messages from the shared job queue.
struct Worker {
/// Index of this worker within its pool (0-based); only used in `Debug` output.
id: usize,
/// Handle of the spawned thread; `None` once the pool has taken it in order to join the thread.
thread: Option<thread::JoinHandle<()>>,
}
impl Worker {
    /// Spawns worker thread number `id`, which processes messages from `receiver` until told to
    /// stop.
    ///
    /// The thread runs each `Message::Work` job in turn. It exits when it receives
    /// `Message::Terminate` or when every sender for the queue has been dropped.
    ///
    /// A job that panics does not take the worker down. The panic is reported by the panic hook
    /// as usual, the job's captured values are dropped, and the worker moves on to the next
    /// message, so the pool keeps its full size.
    ///
    /// # Panics
    ///
    /// Panics if the operating system fails to spawn the thread.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Self {
        let handle = thread::Builder::new()
            .name(format!("thread-pool-worker-{id}"))
            .spawn(move || loop {
                // The guard is a temporary of this statement. The lock is therefore held only
                // while waiting in `recv` and is released before the job runs, so other workers
                // can pick up messages concurrently. `recv` is the only thing done under the
                // lock, so a poisoned lock guards nothing half-updated and is simply recovered.
                let message = receiver
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner)
                    .recv();
                match message {
                    Ok(Message::Work(job)) => {
                        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(job));
                    }
                    Ok(Message::Terminate) | Err(_) => break,
                }
            })
            .expect("failed to spawn thread pool worker");
        Worker {
            id,
            thread: Some(handle),
        }
    }

    /// Returns `true` while this worker still owns its thread handle, i.e. it has not been
    /// taken for joining.
    pub fn is_alive(&self) -> bool {
        self.thread.is_some()
    }
}
impl Debug for Worker {
    /// Renders the worker as `Worker { id: .., is_alive: .. }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let alive = self.is_alive();
        let mut out = f.debug_struct("Worker");
        out.field("id", &self.id);
        out.field("is_alive", &alive);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::WaitGroup;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert!(pool.is_alive());
        assert_eq!(pool.num_workers(), 4);
    }

    #[test]
    fn test_basic_job_execution() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(Mutex::new(0));
        for _ in 0..8 {
            let counter = Arc::clone(&counter);
            pool.submit(move || {
                let mut num = counter.lock().unwrap();
                *num += 1;
            });
        }
        // Dropping the pool queues `Terminate` behind every submitted job and joins all workers.
        // Every job has therefore run once `drop` returns, and no sleep is needed to wait for them.
        drop(pool);
        assert_eq!(*counter.lock().unwrap(), 8);
    }

    #[test]
    fn test_thread_pool() {
        let pool = ThreadPool::new(4);
        let finished = Arc::new(AtomicUsize::new(0));
        for i in 0..8 {
            let finished = Arc::clone(&finished);
            pool.submit(move || {
                // `Instant` is monotonic; `SystemTime::elapsed` fails if the wall clock steps back.
                let start_time = Instant::now();
                println!("Job {} started.", i);
                thread::sleep(Duration::from_secs(1));
                println!("Job {} finished in {:?}.", i, start_time.elapsed());
                finished.fetch_add(1, Ordering::SeqCst);
            });
        }
        drop(pool);
        assert_eq!(finished.load(Ordering::SeqCst), 8);
    }

    #[test]
    fn test_job_order() {
        // With two workers the completion order is not deterministic, so this checks that every
        // job ran exactly once rather than the order in which they ran.
        let pool = ThreadPool::new(2);
        let results = Arc::new(Mutex::new(vec![]));
        for i in 0..5 {
            let results = Arc::clone(&results);
            pool.submit(move || {
                results.lock().unwrap().push(i);
            });
        }
        drop(pool);
        let mut results = results.lock().unwrap();
        results.sort();
        assert_eq!(*results, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_thread_pool_process() {
        let pool = ThreadPool::new(4);
        let results = pool.submit_with_result(|| {
            let start_time = Instant::now();
            println!("Job started.");
            thread::sleep(Duration::from_secs(2));
            println!("Job finished in {:?}.", start_time.elapsed());
            42
        });
        let result = results.result();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_max_concurrent_jobs() {
        let num_workers = 4;
        let pool = ThreadPool::new(num_workers);
        let (tx, rx) = mpsc::channel();
        let num_jobs = 20;
        let job_time = Duration::from_millis(100);
        // Track how many jobs run at once and the highest value ever observed.
        let running = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let start_time = Instant::now();
        for i in 0..num_jobs {
            let tx = tx.clone();
            let running = Arc::clone(&running);
            let peak = Arc::clone(&peak);
            pool.submit(move || {
                let now_running = running.fetch_add(1, Ordering::SeqCst) + 1;
                peak.fetch_max(now_running, Ordering::SeqCst);
                thread::sleep(job_time);
                running.fetch_sub(1, Ordering::SeqCst);
                tx.send(i).unwrap();
            });
        }
        let mut results = vec![];
        for _ in 0..num_jobs {
            results.push(rx.recv().unwrap());
        }
        let elapsed = start_time.elapsed();
        // Running the jobs one after another would take `num_jobs * job_time`; finishing sooner
        // shows that they overlapped.
        assert!(elapsed < job_time * num_jobs as u32);
        let peak = peak.load(Ordering::SeqCst);
        assert!(peak > 1, "jobs never ran concurrently");
        assert!(
            peak <= num_workers,
            "more jobs ran at once than there are workers"
        );
        assert_eq!(results.len(), num_jobs);
        assert!(results.iter().all(|&x| x < num_jobs));
    }

    #[test]
    fn tests_thread_pool_submit_with_result_returns_correct_order() {
        let pool = ThreadPool::new(5);
        let num_jobs = 10;
        let mut work_results = vec![];
        for i in 0..num_jobs {
            let work_result = pool.submit_with_result(move || {
                thread::sleep(Duration::from_millis(50 * (num_jobs - i) as u64));
                i * i
            });
            work_results.push(work_result);
        }
        for (i, work_result) in work_results.into_iter().enumerate() {
            let result = work_result.result();
            assert_eq!(result, i * i);
        }
    }

    #[test]
    fn test_wait_group() {
        let pool = ThreadPool::new(4);
        let wg = WaitGroup::new();
        let num_tasks = 10;
        let total = Arc::new(Mutex::new(0));
        for _ in 0..num_tasks {
            let guard = wg.guard();
            let total = Arc::clone(&total);
            pool.submit(move || {
                thread::sleep(Duration::from_millis(100));
                let mut num = total.lock().unwrap();
                *num += 1;
                drop(guard);
            });
        }
        {
            // Each task sleeps 100 ms first, so at this point they cannot all have finished.
            let total = total.lock().unwrap();
            assert_ne!(*total, num_tasks);
        }
        let total_tasks_waited_for = wg.wait();
        let total = total.lock().unwrap();
        assert_eq!(*total, num_tasks);
        assert_eq!(total_tasks_waited_for, num_tasks);
    }

    #[test]
    fn test_wait_group_zero_tasks() {
        let wg = WaitGroup::new();
        let total_tasks_waited_for = wg.wait();
        assert_eq!(total_tasks_waited_for, 0);
    }
}