job_pool/
pool.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
use std::sync::mpsc::SendError;
use std::sync::{mpsc, Arc, Condvar, Mutex};
use crate::worker::{Job, Worker};
use crate::{PoolConfig, Result, Semaphore};

/// Sending half of a job channel, abstracting over the two flavours of
/// [`std::sync::mpsc`] sender.
pub enum SenderWrapper<T> {
    /// Sender of a channel created with [`mpsc::sync_channel`].
    Bounded(mpsc::SyncSender<T>),
    /// Sender of a channel created with [`mpsc::channel`].
    Unbounded(mpsc::Sender<T>),
}

impl<T> SenderWrapper<T> {
    /// Forwards `t` to the wrapped sender.
    ///
    /// # Errors
    /// Returns the [`SendError`] produced by the underlying sender, which
    /// hands the value back to the caller.
    fn send(&self, t: T) -> std::result::Result<(), SendError<T>> {
        match self {
            Self::Bounded(sync_tx) => sync_tx.send(t),
            Self::Unbounded(async_tx) => async_tx.send(t),
        }
    }
}

/// Thread Pool
///
/// A thread pool coordinates a group of threads to run
/// tasks in parallel.
///
/// # Example
/// ```
/// use job_pool::ThreadPool;
///
/// let pool = ThreadPool::with_size(32).expect("Error creating pool");
/// pool.execute(|| println!("Hello world!"));
/// ```
pub struct ThreadPool {
    /// Workers created in `new`; each one is shut down when the pool is dropped.
    workers: Vec<Worker>,
    /// Sending half of the job channel. Wrapped in an `Option` so that `Drop`
    /// can take it out and drop it before shutting the workers down.
    sender: Option<SenderWrapper<Box<dyn Job>>>,
    /// Shared job counter plus condition variable. `execute` increments the
    /// counter and `join` waits for it to reach zero.
    // NOTE(review): the decrement is not in this file — presumably done by
    // the workers when a job finishes; confirm in `worker`.
    semaphore: Semaphore,
    /// Optional upper bound on the job counter; `execute` blocks while the
    /// counter is at or above this value.
    max_jobs: Option<u16>,
}

impl ThreadPool {
    /// Builds a pool from the given [configuration](PoolConfig).
    ///
    /// Creates `config.n_workers` workers that all share the receiving end of
    /// one job channel. The channel is bounded when
    /// `config.incoming_buf_size` is set and unbounded otherwise.
    ///
    /// The current implementation has no failing path and always returns `Ok`.
    pub fn new(config: PoolConfig) -> Result<ThreadPool> {
        let n_workers = config.n_workers as usize;

        let (sender, receiver) = match config.incoming_buf_size {
            Some(capacity) => {
                let (tx, rx) = mpsc::sync_channel(capacity as usize);
                (SenderWrapper::Bounded(tx), rx)
            }
            None => {
                let (tx, rx) = mpsc::channel();
                (SenderWrapper::Unbounded(tx), rx)
            }
        };

        // Every worker gets a handle to the same receiver and the same
        // counter/condvar pair.
        let receiver = Arc::new(Mutex::new(receiver));
        let semaphore = Arc::new((Mutex::new(0), Condvar::new()));
        let workers = (0..n_workers)
            .map(|_| Worker::new(Arc::clone(&receiver), Arc::clone(&semaphore)))
            .collect();

        Ok(ThreadPool {
            workers,
            sender: Some(sender),
            semaphore,
            max_jobs: config.max_jobs,
        })
    }

    /// Builds a [ThreadPool] using the default [configuration](PoolConfig).
    ///
    /// # Errors
    /// Fails if the configuration builder reports an error.
    #[inline]
    pub fn with_default_config() -> Result<Self> {
        let builder = PoolConfig::builder();
        let conf = builder.build().map_err(|err| err.to_string())?;
        Self::new(conf)
    }

    /// Builds a [ThreadPool] with `size` workers, leaving every other
    /// setting at the builder's default.
    ///
    /// # Errors
    /// Fails if the configuration builder reports an error.
    #[inline]
    pub fn with_size(size: u16) -> Result<Self> {
        let builder = PoolConfig::builder().n_workers(size);
        let conf = builder.build().map_err(|err| err.to_string())?;
        Self::new(conf)
    }

    /// Submits `job` to the pool.
    ///
    /// When `max_jobs` is set, this call blocks while the shared job counter
    /// is at or above that limit. The counter is incremented before the job
    /// is pushed onto the channel.
    ///
    /// # Panics
    /// Panics if the counter mutex is poisoned, if the sender has already
    /// been taken, or if sending on the channel fails.
    pub fn execute(&self, job: impl Job) {
        // Non-generic inner function: only the boxing is monomorphised per
        // job type.
        fn dispatch(pool: &ThreadPool, boxed: Box<dyn Job>) {
            let (mutex, condvar) = &*pool.semaphore;
            let mut pending = mutex.lock().unwrap();
            if let Some(limit) = pool.max_jobs {
                pending = condvar
                    .wait_while(pending, |count| *count >= limit)
                    .unwrap();
            }
            *pending += 1;
            // Release the counter lock before sending.
            drop(pending);

            let sender = pool.sender.as_ref().unwrap();
            sender.send(boxed).unwrap();
        }

        dispatch(self, Box::new(job));
    }

    /// Blocks the calling thread until the shared job counter drops to zero.
    ///
    /// # Panics
    /// Panics if the counter mutex is poisoned.
    pub fn join(&self) {
        let (mutex, condvar) = &*self.semaphore;
        let guard = mutex.lock().unwrap();
        let idle = condvar.wait_while(guard, |pending| *pending > 0).unwrap();
        drop(idle);
    }
}

impl Drop for ThreadPool {
    /// Drops the pool's sender, then shuts down every worker in order.
    fn drop(&mut self) {
        // Release the sending half first; the workers are shut down only
        // after it is gone.
        // NOTE(review): presumably `Worker::shutdown` relies on the channel
        // being closed to stop its loop — confirm in `worker`.
        self.sender = None;
        for worker in self.workers.iter_mut() {
            Worker::shutdown(worker);
        }
    }
}