executor_service/
lib.rs

1use std::fmt::{Display, Formatter};
2use std::marker::PhantomData;
3use std::thread;
4use std::sync::mpsc::{channel, sync_channel, Sender, SyncSender, Receiver, SendError, RecvError};
5use std::sync::{Arc, Mutex, Condvar};
6use log::trace;
7
/// Unsized type of a task that can be handed to another thread: a closure that
/// is callable once, sendable across threads, owns all of its captures and
/// yields a `T`.
pub type Runnable<T> = dyn FnOnce() -> T + Send + 'static;
9
///
/// Maximum number of threads that can be requested for a single pool.
/// This constant does not play an overarching role. In other words,
/// if you have multiple thread pools and the system supports it,
/// you might have a total thread count of more than [MAX_THREAD_COUNT] for
/// the entire application.
pub const MAX_THREAD_COUNT: u32 = 150;
17
///
/// Default number of threads a cached pool starts with when the caller
/// does not provide an initial count.
///
pub const DEFAULT_INITIAL_CACHED_THREAD_COUNT: u32 = 10;
22
/// The growth policy of a thread pool.
#[derive(Debug, Clone)]
pub enum PoolType {
  /// The pool may spawn additional threads when every thread is busy,
  /// up to [MAX_THREAD_COUNT] threads.
  Cached,
  /// The pool keeps the number of threads it was created with; a task
  /// submitted while every thread is busy waits for one to become free.
  Fixed,
}
28
/// Errors reported by the executor service.
#[derive(Debug)]
pub enum ExecutorServiceError {
  /// An argument was rejected; the message explains why.
  ParameterError(String),
  /// An underlying I/O error, converted from [std::io::Error].
  IOError(std::io::Error),
  /// A message could not be sent over an internal channel, or an internal
  /// lock could not be acquired.
  ProcessingError,
  /// The result of a task could not be received from its result channel.
  ResultReceptionError,
}
36
37impl<T> From<SendError<T>> for ExecutorServiceError {
38  fn from(_: SendError<T>) -> Self {
39    ExecutorServiceError::ProcessingError
40  }
41}
42
43impl From<RecvError> for ExecutorServiceError {
44  fn from(_: RecvError) -> Self {
45    ExecutorServiceError::ResultReceptionError
46  }
47}
48
49impl From<std::io::Error> for ExecutorServiceError {
50  fn from(value: std::io::Error) -> Self {
51    ExecutorServiceError::IOError(value)
52  }
53}
54
55impl Display for ExecutorServiceError {
56  fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
57    match self {
58      ExecutorServiceError::ParameterError(message) => write!(f, "{:}: {:}", "ParameterError", message.as_str()),
59      ExecutorServiceError::IOError(io_error) => write!(f, "{:}: {:}", "IOError", io_error),
60      ExecutorServiceError::ResultReceptionError => write!(f, "{:}", "ResultReceptionError"),
61      ExecutorServiceError::ProcessingError => write!(f, "{:}", "ProcessingError"),
62    }
63  }
64}
65
// Marker implementation: relies on the default methods of `std::error::Error`
// together with the `Debug` and `Display` implementations above.
impl std::error::Error for ExecutorServiceError {}
67
/// Handle to the result of a task that was submitted asynchronously.
/// The result is obtained with [Future::get].
pub struct Future<T: Send + 'static> {
  // Receiving end of the channel on which the task's return value arrives.
  result_receiver: Receiver<T>,
}
71
72
73impl<T: Send + 'static> Future<T> {
74  pub fn get(&self) -> Result<T, ExecutorServiceError> {
75    Ok(self.result_receiver.recv()?)
76  }
77}
78
// Messages sent from an `ExecutorService` to its dispatcher thread.
enum DispatcherEventType<F, T>
  where F: FnOnce() -> T,
        T: Send + 'static,
        F: Send + 'static
{
  // Run the function on a pool thread. When a sender is supplied, the
  // function's return value is delivered through it.
  Execute(Option<SyncSender<T>>, F),
  // Shut the pool down.
  Quit,
}
87
// Messages sent from the dispatcher thread to a worker thread.
enum EventType<F, T>
  where F: FnOnce() -> T,
        T: Send + 'static,
        F: Send + 'static
{
  // Run the function. The optional `SyncSender` receives the return value;
  // the `Sender<Self>` is the worker's own channel handle, which the worker
  // puts back on the list of available threads once the function is done.
  Execute(Option<SyncSender<T>>, Sender<Self>, F),
  // Leave the worker loop.
  Quit,
}
96
97
98impl<F: Send + 'static + FnOnce() -> T, T: Send + 'static> Display for DispatcherEventType<F, T> {
99  fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
100    write!(f, "EventType::{:}",
101           match self {
102             Self::Execute(_, _) => "Execute",
103             Self::Quit => "Quit",
104           }
105    )
106  }
107}
108
109impl<F: Send + 'static + FnOnce() -> T, T: Send + 'static> Display for EventType<F, T> {
110  fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
111    write!(f, "EventType::{:}",
112           match self {
113             Self::Execute(_, _, _) => "Execute",
114             Self::Quit => "Quit",
115           }
116    )
117  }
118}
119
///
/// The executor service that allows tasks to be submitted/executed
/// on the underlying thread pool. Instances are created through
/// [Executors].
/// ```
/// use executor_service::Executors;
/// use std::thread::sleep;
/// use core::time::Duration;
///
/// let mut executor_service = Executors::new_fixed_thread_pool(2).expect("Failed to create the thread pool");
///
/// let some_param = "Mr White";
/// let res = executor_service.submit_sync(move || {
///
///   sleep(Duration::from_secs(5));
///   println!("Hello {:}", some_param);
///   println!("Long computation finished");
///   2
/// }).expect("Failed to submit function");
///
/// println!("Result: {:#?}", res);
/// assert_eq!(res, 2);
///```
pub struct ExecutorService<F, T>
  where F: FnOnce() -> T,
        F: Send + 'static,
        T: Send + 'static {
  // Channel to the dispatcher thread, which hands tasks over to pool threads.
  dispatcher: SyncSender<DispatcherEventType<F, T>>,
  // Growth policy the pool was created with.
  pool_type: PoolType,
  // Number of threads in the pool; shared with the dispatcher thread, which
  // updates it when a cached pool grows.
  thread_count: Arc<Mutex<u32>>,
}
150
151impl<F, T> ExecutorService<F, T>
152  where F: FnOnce() -> T,
153        F: Send + 'static,
154        T: Send + 'static {
155  ///
156  /// Execute a function on the thread pool asynchronously with no return.
157  /// ```
158  /// use executor_service::Executors;
159  /// use std::thread::sleep;
160  /// use core::time::Duration;
161  /// use std::thread;
162  ///
163  /// let mut executor_service = Executors::new_fixed_thread_pool(2).expect("Failed to create the thread pool");
164  ///
165  /// let some_param = "Mr White";
166  /// let res = executor_service.execute(move || {
167  ///   sleep(Duration::from_secs(1));
168  ///   println!("Hello {:} from thread {:}", some_param, thread::current().name().unwrap());
169  /// }).expect("Failed to execute function");
170  ///
171  /// sleep(Duration::from_secs(3));
172  ///```
173  ///
174  pub fn execute(&mut self, fun: F) -> Result<(), ExecutorServiceError> {
175    Ok(self.dispatcher.send(DispatcherEventType::Execute(None, fun))?)
176  }
177
178  ///
179  /// Submit a function and wait for its result synchronously
180  /// ```
181  /// use executor_service::Executors;
182  /// use std::thread::sleep;
183  /// use core::time::Duration;
184  ///
185  /// let mut executor_service = Executors::new_cached_thread_pool(None).expect("Failed to create the thread pool");
186  ///
187  /// let some_param = "Mr White";
188  /// let res = executor_service.submit_sync(move || {
189  ///
190  ///   sleep(Duration::from_secs(5));
191  ///   println!("Hello {:}", some_param);
192  ///   println!("Long computation finished");
193  ///   2
194  /// }).expect("Failed to submit function");
195  ///
196  /// println!("Result: {:#?}", res);
197  /// assert_eq!(res, 2);
198  ///```
199  pub fn submit_sync(&mut self, fun: F) -> Result<T, ExecutorServiceError> {
200    let (s, r) = sync_channel(1);
201    self.dispatcher.send(DispatcherEventType::Execute(Some(s), fun))?;
202    Ok(r.recv()?)
203  }
204
205  ///
206  /// Submit a function and get a Future object to obtain the result
207  /// asynchronously when needed.
208  /// ```
209  /// use executor_service::Executors;
210  /// use std::thread::sleep;
211  /// use core::time::Duration;
212  ///
213  /// let mut executor_service = Executors::new_cached_thread_pool(Some(5)).expect("Failed to create the thread pool");
214  ///
215  /// let some_param = "Mr White";
216  /// let future = executor_service.submit_async(Box::new(move || {
217  ///
218  ///   sleep(Duration::from_secs(3));
219  ///   println!("Hello {:}", some_param);
220  ///   println!("Long computation finished");
221  ///   "Some string result".to_string()
222  /// })).expect("Failed to submit function");
223  ///
224  /// //Wait a bit more to see the future work.
225  /// println!("Main thread wait for 5 seconds");
226  /// sleep(Duration::from_secs(5));
227  /// let res = future.get().expect("Couldn't get a result");
228  /// println!("Result is {:}", &res);
229  /// assert_eq!(&res, "Some string result");
230  ///```
231  pub fn submit_async(&mut self, fun: F) -> Result<Future<T>, ExecutorServiceError>
232  {
233    let (s, r) = sync_channel(1);
234    self.dispatcher.send(DispatcherEventType::Execute(Some(s), fun))?;
235
236    Ok(Future {
237      result_receiver: r
238    })
239  }
240
241  pub fn pool_type(&self) -> &PoolType {
242    &self.pool_type
243  }
244
245
246  pub fn get_thread_count(&self) -> Result<u32, ExecutorServiceError> {
247    match self.thread_count.lock() {
248      Ok(lock) => Ok(*lock),
249      Err(_) => Err(ExecutorServiceError::ProcessingError)
250    }
251  }
252}
253
254impl<F, T> Drop for ExecutorService<F, T>
255  where F: FnOnce() -> T,
256        F: Send + 'static,
257        T: Send + 'static {
258  fn drop(&mut self) {
259    self.dispatcher.send(DispatcherEventType::Quit).unwrap();
260  }
261}
262
263
/// Factory for [ExecutorService] instances. It only offers associated
/// functions and is never instantiated by the code in this file.
pub struct Executors<F, T> where F: FnOnce() -> T,
                                 F: Send + 'static,
                                 T: Send + 'static {
  // Zero-sized marker so that the type parameter `F` is used by the struct.
  _phantom: PhantomData<F>,
}
269
270impl<F, T> Executors<F, T> where F: FnOnce() -> T,
271                                 F: Send + 'static,
272                                 T: Send + 'static {
273  ///
274  /// Creates a thread pool with a fixed size. All threads are initialized at first.
275  ///
276  /// `REMARKS`: The maximum value for [thread_count] is currently [MAX_THREAD_COUNT]
277  /// If you go beyond that, the function will fail, producing an [ExecutorServiceError::ParameterError]
278  ///
279  pub fn new_fixed_thread_pool(thread_count: u32) -> Result<ExecutorService<F, T>, ExecutorServiceError> {
280    if thread_count > MAX_THREAD_COUNT {
281      return Err(ExecutorServiceError::ProcessingError);
282    }
283
284    let thread_count_mutex = Arc::new(Mutex::new(thread_count));
285    let pool_type = PoolType::Fixed;
286    let sender = Self::prepare_pool(thread_count, pool_type.clone(), thread_count_mutex.clone())?;
287
288    Ok(ExecutorService {
289      dispatcher: sender,
290      pool_type,
291      thread_count: thread_count_mutex,
292    })
293  }
294
295
296  ///
297  /// Creates a cached thread pool with an optional initial thread count. If the initial
298  /// count is not provided, then a default of [DEFAULT_INITIAL_CACHED_THREAD_COUNT] threads will be initiated. When a new
299  /// task is posted to the pool, if there are no threads available, then a new thread
300  /// will be added to the pool and will then be cached. So the number of underlying
301  /// threads is likely to increase with respect to the needs.
302  ///
303  /// `REMARKS`: The maximum value for `initial_thread_count` is currently [MAX_THREAD_COUNT]. And
304  /// the maximum number of thread that can be created is also limited to [MAX_THREAD_COUNT] by design.
305  /// If more requests come and all threads are busy and we have a maximum of [MAX_THREAD_COUNT] threads,
306  /// then it will behave like a constant thread pool.
307  ///
308  pub fn new_cached_thread_pool(initial_thread_count: Option<u32>) -> Result<ExecutorService<F, T>, ExecutorServiceError> {
309    let initial_count = if let Some(count) = initial_thread_count {
310      if count > MAX_THREAD_COUNT {
311        return Err(ExecutorServiceError::ParameterError(format!("Max thread count is {:}", MAX_THREAD_COUNT)));
312      }
313      count
314    } else {
315      DEFAULT_INITIAL_CACHED_THREAD_COUNT
316    };
317
318
319    let pool_type = PoolType::Cached;
320    let thread_count_mutex = Arc::new(Mutex::new(initial_count));
321
322    let sender = Self::prepare_pool(initial_count, pool_type.clone(), thread_count_mutex.clone())?;
323
324    Ok(ExecutorService {
325      dispatcher: sender,
326      pool_type,
327      thread_count: thread_count_mutex,
328    })
329  }
330
331  fn prepare_pool(initial_count: u32, pool_type: PoolType, thread_count_mutex: Arc<Mutex<u32>>)
332                  -> Result<SyncSender<DispatcherEventType<F, T>>, ExecutorServiceError> {
333    let available = Arc::new(Mutex::new(vec![]));
334
335    let (sender, receiver) = sync_channel(1);
336
337    let pool_waiter = Arc::new(Condvar::new());
338
339    for i in 0..initial_count {
340      let (s, r) = channel();
341
342      if let Ok(mut lock) = available.lock() {
343        lock.push(s);
344      }
345
346      Self::create_thread(i, r, available.clone(), pool_waiter.clone())?;
347    }
348
349    Self::prepare_dispatcher(available, receiver, pool_waiter, pool_type, initial_count, thread_count_mutex)?;
350    Ok(sender)
351  }
352
353  fn create_thread(i: u32, r: Receiver<EventType<F, T>>,
354                   vec_clone: Arc<Mutex<Vec<Sender<EventType<F, T>>>>>,
355                   cv_clone: Arc<Condvar>)
356                   -> Result<(), ExecutorServiceError> {
357    thread::Builder::new()
358      .name(format!("Thread-{:}", i)).spawn(move || {
359      loop {
360        let fun = r.recv().unwrap();
361        match fun {
362          EventType::Execute(result_sender, sender, fun) => {
363            let t = fun();
364            if let Some(res_sender) = result_sender {
365              //TODO: Check the result
366              res_sender.send(t).unwrap();
367            }
368            if let Ok(mut lock) = vec_clone.lock() {
369              lock.push(sender);
370              cv_clone.notify_all();
371            }
372          }
373          EventType::Quit => {
374            trace!("{:} Received exit", thread::current().name().unwrap());
375            break;
376          }
377        }
378      }
379      //trace!("{:} Loop done", thread::current().name().unwrap())
380    })?;
381
382    Ok(())
383  }
384
385  fn prepare_dispatcher(available: Arc<Mutex<Vec<Sender<EventType<F, T>>>>>,
386                        receiver: Receiver<DispatcherEventType<F, T>>,
387                        pool_waiter: Arc<Condvar>,
388                        pool_type: PoolType,
389                        current_thread_count: u32,
390                        thread_count_mutex: Arc<Mutex<u32>>) -> Result<(), ExecutorServiceError> {
391    thread::Builder::new()
392      .name("Dispatcher".into())
393      .spawn(move || {
394        //shadowing deliberately
395        let mut current_thread_count = current_thread_count;
396        loop {
397          match receiver.recv().unwrap() {
398            DispatcherEventType::Execute(result_sender, func) => {
399              if let Ok(mut lock) = available.lock() {
400                if lock.is_empty() {
401                  //threads are busy
402                  match pool_type {
403                    PoolType::Cached => {
404                      //the pool is cached.
405                      if current_thread_count < MAX_THREAD_COUNT {
406                        //spawn a new thread
407                        let (s, r) = channel::<EventType<F, T>>();
408                        lock.push(s);
409                        //FIXME: use this result!
410                        Self::create_thread(current_thread_count, r, available.clone(), pool_waiter.clone());
411                        current_thread_count += 1;
412                        let mut count_lock = thread_count_mutex.lock().unwrap();
413                        *count_lock = current_thread_count;
414                      } else {
415                        //we already have a maximum, so wait again
416                        let _ = pool_waiter.wait(lock);
417                      }
418                    }
419                    PoolType::Fixed => {
420                      //the pool is fixed, we have to wait.
421                      let _ = pool_waiter.wait(lock);
422                    }
423                  }
424                }
425              };
426
427              if let Ok(mut lock) = available.lock() {
428                //trace!("Available: {:}", lock.len());
429                let the_sender = lock.pop().unwrap();
430                //trace!("Available: {:}", lock.len());
431                the_sender.send(EventType::Execute(result_sender, the_sender.clone(), func)).unwrap();
432              };
433            }
434            DispatcherEventType::Quit => {
435              //trace!("Dispatcher received Quit");
436              if let Ok(lock) = available.lock() {
437                for x in &*lock {
438                  trace!("AV Send quit");
439                  let _ = x.send(EventType::Quit);
440                }
441              }
442
443              break;
444            }
445          }
446        }
447        trace!("Dispatcher exit");
448      })?;
449
450    Ok(())
451  }
452}
453
454
// Tests for the executor service. They run real threads and sleep for
// several seconds, so the module takes a while to complete.
#[cfg(test)]
mod tests {
  use std::time::Duration;
  use super::*;
  use std::thread::sleep;
  use std::thread;
  use std::sync::mpsc::sync_channel;
  use env_logger::{Builder, Env};
  use log::{debug, info};

  // Initializes env_logger with a default filter of "trace".
  // NOTE(review): the `cfg(test)` attribute below repeats the one on the module.
  #[cfg(test)]
  #[ctor::ctor]
  fn init_env_logger() {
    Builder::from_env(Env::default().default_filter_or("trace")).init();
  }

  // Executes 100 fire-and-forget tasks on a fixed pool of 10 threads and
  // waits until each of them has reported back over a channel.
  #[test]
  fn test_execute() -> Result<(), ExecutorServiceError> {
    let max = 100;
    let mut executor_service = Executors::new_fixed_thread_pool(10)?;

    // The channel can buffer one message per task, so no task blocks on send.
    let (sender, receiver) = sync_channel(max);
    for i in 0..max {
      let moved_i = i;

      let sender2 = sender.clone();

      executor_service.execute(Box::new(move || {
        sleep(Duration::from_millis(10));
        info!("Hello from {:} {:}", thread::current().name().unwrap(), moved_i);
        sender2.send(1).expect("Send failed");
      }))?;
    }

    // Counts down once per received completion message.
    let mut latch_count = max;

    loop {
      let _ = &receiver.recv().unwrap();
      latch_count -= 1;

      if latch_count == 0 {
        break; //all tasks are done
      }
    };

    Ok(())
  }

  // Submits one task synchronously to a fixed pool and checks its return value.
  #[test]
  fn test_submit_sync() -> Result<(), ExecutorServiceError> {
    let mut executor_service = Executors::new_fixed_thread_pool(2)?;

    let some_param = "Mr White";
    let res = executor_service.submit_sync(Box::new(move || {
      info!("Long computation");
      sleep(Duration::from_secs(5));
      debug!("Hello {:}", some_param);
      info!("Long computation finished");
      2
    }))?;

    trace!("Result: {:#?}", res);
    assert_eq!(res, 2);
    Ok(())
  }

  // Submits one task asynchronously to a fixed pool and reads the result
  // from the future after the task had time to finish.
  #[test]
  fn test_submit_async() -> Result<(), ExecutorServiceError> {
    let mut executor_service = Executors::new_fixed_thread_pool(2)?;

    let some_param = "Mr White";
    let res: Future<String> = executor_service.submit_async(Box::new(move || {
      info!("Long computation");
      sleep(Duration::from_secs(5));
      debug!("Hello {:}", some_param);
      info!("Long computation finished");
      "A string as a result".to_string()
    }))?;

    // The main thread sleeps longer than the task runs, so the result is
    // already waiting when the future is consumed.
    info!("Main thread wait for 7 seconds");
    sleep(Duration::from_secs(7));
    info!("Main thread resumes after 7 seconds, consuming the future");
    let the_string = res.get()?;
    trace!("Result: {:#?}", &the_string);
    assert_eq!(&the_string, "A string as a result");
    Ok(())
  }

  // Posts 100 long-running tasks (15 seconds each) to a cached pool at
  // 100 ms intervals, logging the reported thread count before each one,
  // then waits for all of them to complete.
  #[test]
  fn test_cahced_thread_pool_execute() -> Result<(), ExecutorServiceError> {
    let mut executor_service = Executors::new_cached_thread_pool(None)?;

    let (s, r) = channel();
    let some_param = "Mr White";

    for _ in 0..100 {
      let s = s.clone();
      sleep(Duration::from_millis(100));
      debug!("Thread count is {:}", executor_service.get_thread_count().unwrap());
      executor_service.execute(move || {
        info!("Long computation Thread:{:}", thread::current().name().unwrap());
        sleep(Duration::from_millis(15000));
        debug!("Hello {:}", some_param);
        info!("Long computation finished");
        s.send("asdf").expect("Cannot send");
      })?;
    }

    // One message per task.
    for _ in 0..100 {
      r.recv().expect("Cannot receive");
    }

    Ok(())
  }


  // Submits one task asynchronously to a cached pool and consumes the
  // future on a different thread than the one that submitted it.
  #[test]
  fn test_submit_async_cached() -> Result<(), ExecutorServiceError> {
    let mut executor_service = Executors::new_cached_thread_pool(Some(5))?;

    info!("{:?}", executor_service.get_thread_count());
    let some_param = "Mr White";
    let the_future: Future<String> = executor_service.submit_async(Box::new(move || {
      info!("Long computation");
      sleep(Duration::from_secs(5));
      debug!("Hello {:}", some_param);
      info!("Long computation finished");
      "A string as a result".to_string()
    }))?;

    // The main thread sleeps longer than the task runs, so the result is
    // already waiting when the future is consumed.
    info!("Main thread wait for 7 seconds");
    sleep(Duration::from_secs(7));
    info!("Main thread resumes after 7 seconds, consuming the future");


    // The future is moved into the spawned thread; `join` propagates a
    // failed assertion from that thread to the test.
    thread::spawn(move || {
      let the_string = the_future.get().expect("No result");
      trace!("Result: {:#?}", &the_string);
      assert_eq!(&the_string, "A string as a result");
    }).join().expect("Join failed");
    Ok(())
  }
}