1use std::fmt::{Display, Formatter};
2use std::marker::PhantomData;
3use std::thread;
4use std::sync::mpsc::{channel, sync_channel, Sender, SyncSender, Receiver, SendError, RecvError};
5use std::sync::{Arc, Mutex, Condvar};
6use log::trace;
7
/// Unsized trait-object type of a one-shot task that yields a `T` and can be
/// handed to another thread; typically used behind a pointer such as
/// `Box<Runnable<T>>`.
pub type Runnable<T> = dyn FnOnce() -> T + Send + 'static;
9
/// Upper bound on the number of worker threads of a pool. Both pool
/// constructors reject a larger request, and a cached pool stops growing once
/// it has this many workers.
pub const MAX_THREAD_COUNT: u32 = 150_u32;

/// Number of workers a cached pool starts with when the caller passes `None`
/// as the initial thread count.
pub const DEFAULT_INITIAL_CACHED_THREAD_COUNT: u32 = 10_u32;
22
/// Strategy a pool uses when a task arrives while every worker is busy.
///
/// The enum carries no data, so it is `Copy` and comparable; callers can test
/// the value returned by `ExecutorService::pool_type` with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolType {
    /// A new worker thread is started on demand until `MAX_THREAD_COUNT` is
    /// reached; after that the task waits for a worker to become free.
    Cached,
    /// The number of workers never changes; the task waits for a worker to
    /// become free.
    Fixed,
}
28
/// Everything that can go wrong while building or using an executor service.
#[derive(Debug)]
pub enum ExecutorServiceError {
    /// A constructor argument was rejected; the payload explains why.
    ParameterError(String),
    /// An I/O failure reported by the standard library (thread spawning
    /// returns `std::io::Error`).
    IOError(std::io::Error),
    /// A message could not be sent over one of the internal channels, or an
    /// internal lock could not be taken.
    ProcessingError,
    /// The result of a submitted task could not be received.
    ResultReceptionError,
}
36
37impl<T> From<SendError<T>> for ExecutorServiceError {
38 fn from(_: SendError<T>) -> Self {
39 ExecutorServiceError::ProcessingError
40 }
41}
42
43impl From<RecvError> for ExecutorServiceError {
44 fn from(_: RecvError) -> Self {
45 ExecutorServiceError::ResultReceptionError
46 }
47}
48
49impl From<std::io::Error> for ExecutorServiceError {
50 fn from(value: std::io::Error) -> Self {
51 ExecutorServiceError::IOError(value)
52 }
53}
54
55impl Display for ExecutorServiceError {
56 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
57 match self {
58 ExecutorServiceError::ParameterError(message) => write!(f, "{:}: {:}", "ParameterError", message.as_str()),
59 ExecutorServiceError::IOError(io_error) => write!(f, "{:}: {:}", "IOError", io_error),
60 ExecutorServiceError::ResultReceptionError => write!(f, "{:}", "ResultReceptionError"),
61 ExecutorServiceError::ProcessingError => write!(f, "{:}", "ProcessingError"),
62 }
63 }
64}
65
66impl std::error::Error for ExecutorServiceError {}
67
/// Handle to the result of a task that was submitted asynchronously.
pub struct Future<T>
where
    T: Send + 'static,
{
    /// Receiving end of the channel on which the result is delivered.
    result_receiver: Receiver<T>,
}
71
72
73impl<T: Send + 'static> Future<T> {
74 pub fn get(&self) -> Result<T, ExecutorServiceError> {
75 Ok(self.result_receiver.recv()?)
76 }
77}
78
/// Message sent from an `ExecutorService` handle to the dispatcher thread.
enum DispatcherEventType<F, T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    /// Run the closure on a worker; when a sender is supplied, the closure's
    /// return value is delivered through it.
    Execute(Option<SyncSender<T>>, F),
    /// Stop dispatching and shut the pool down.
    Quit,
}
87
/// Message sent from the dispatcher thread to a single worker thread.
enum EventType<F, T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    /// Run the closure. The optional `SyncSender` receives the closure's
    /// return value; the `Sender` is the worker's own inbox handle, which the
    /// worker puts back into the pool of available workers when it is done.
    Execute(Option<SyncSender<T>>, Sender<Self>, F),
    /// Leave the worker loop.
    Quit,
}
96
97
98impl<F: Send + 'static + FnOnce() -> T, T: Send + 'static> Display for DispatcherEventType<F, T> {
99 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
100 write!(f, "EventType::{:}",
101 match self {
102 Self::Execute(_, _) => "Execute",
103 Self::Quit => "Quit",
104 }
105 )
106 }
107}
108
109impl<F: Send + 'static + FnOnce() -> T, T: Send + 'static> Display for EventType<F, T> {
110 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
111 write!(f, "EventType::{:}",
112 match self {
113 Self::Execute(_, _, _) => "Execute",
114 Self::Quit => "Quit",
115 }
116 )
117 }
118}
119
120pub struct ExecutorService<F, T>
143 where F: FnOnce() -> T,
144 F: Send + 'static,
145 T: Send + 'static {
146 dispatcher: SyncSender<DispatcherEventType<F, T>>,
147 pool_type: PoolType,
148 thread_count: Arc<Mutex<u32>>,
149}
150
151impl<F, T> ExecutorService<F, T>
152 where F: FnOnce() -> T,
153 F: Send + 'static,
154 T: Send + 'static {
155 pub fn execute(&mut self, fun: F) -> Result<(), ExecutorServiceError> {
175 Ok(self.dispatcher.send(DispatcherEventType::Execute(None, fun))?)
176 }
177
178 pub fn submit_sync(&mut self, fun: F) -> Result<T, ExecutorServiceError> {
200 let (s, r) = sync_channel(1);
201 self.dispatcher.send(DispatcherEventType::Execute(Some(s), fun))?;
202 Ok(r.recv()?)
203 }
204
205 pub fn submit_async(&mut self, fun: F) -> Result<Future<T>, ExecutorServiceError>
232 {
233 let (s, r) = sync_channel(1);
234 self.dispatcher.send(DispatcherEventType::Execute(Some(s), fun))?;
235
236 Ok(Future {
237 result_receiver: r
238 })
239 }
240
241 pub fn pool_type(&self) -> &PoolType {
242 &self.pool_type
243 }
244
245
246 pub fn get_thread_count(&self) -> Result<u32, ExecutorServiceError> {
247 match self.thread_count.lock() {
248 Ok(lock) => Ok(*lock),
249 Err(_) => Err(ExecutorServiceError::ProcessingError)
250 }
251 }
252}
253
254impl<F, T> Drop for ExecutorService<F, T>
255 where F: FnOnce() -> T,
256 F: Send + 'static,
257 T: Send + 'static {
258 fn drop(&mut self) {
259 self.dispatcher.send(DispatcherEventType::Quit).unwrap();
260 }
261}
262
263
/// Namespace for the constructors of `ExecutorService`; it only exists to
/// carry the closure type `F` and its return type `T`.
pub struct Executors<F, T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    /// Zero-sized marker that ties the type parameters to the struct.
    _phantom: PhantomData<F>,
}
269
270impl<F, T> Executors<F, T> where F: FnOnce() -> T,
271 F: Send + 'static,
272 T: Send + 'static {
273 pub fn new_fixed_thread_pool(thread_count: u32) -> Result<ExecutorService<F, T>, ExecutorServiceError> {
280 if thread_count > MAX_THREAD_COUNT {
281 return Err(ExecutorServiceError::ProcessingError);
282 }
283
284 let thread_count_mutex = Arc::new(Mutex::new(thread_count));
285 let pool_type = PoolType::Fixed;
286 let sender = Self::prepare_pool(thread_count, pool_type.clone(), thread_count_mutex.clone())?;
287
288 Ok(ExecutorService {
289 dispatcher: sender,
290 pool_type,
291 thread_count: thread_count_mutex,
292 })
293 }
294
295
296 pub fn new_cached_thread_pool(initial_thread_count: Option<u32>) -> Result<ExecutorService<F, T>, ExecutorServiceError> {
309 let initial_count = if let Some(count) = initial_thread_count {
310 if count > MAX_THREAD_COUNT {
311 return Err(ExecutorServiceError::ParameterError(format!("Max thread count is {:}", MAX_THREAD_COUNT)));
312 }
313 count
314 } else {
315 DEFAULT_INITIAL_CACHED_THREAD_COUNT
316 };
317
318
319 let pool_type = PoolType::Cached;
320 let thread_count_mutex = Arc::new(Mutex::new(initial_count));
321
322 let sender = Self::prepare_pool(initial_count, pool_type.clone(), thread_count_mutex.clone())?;
323
324 Ok(ExecutorService {
325 dispatcher: sender,
326 pool_type,
327 thread_count: thread_count_mutex,
328 })
329 }
330
331 fn prepare_pool(initial_count: u32, pool_type: PoolType, thread_count_mutex: Arc<Mutex<u32>>)
332 -> Result<SyncSender<DispatcherEventType<F, T>>, ExecutorServiceError> {
333 let available = Arc::new(Mutex::new(vec![]));
334
335 let (sender, receiver) = sync_channel(1);
336
337 let pool_waiter = Arc::new(Condvar::new());
338
339 for i in 0..initial_count {
340 let (s, r) = channel();
341
342 if let Ok(mut lock) = available.lock() {
343 lock.push(s);
344 }
345
346 Self::create_thread(i, r, available.clone(), pool_waiter.clone())?;
347 }
348
349 Self::prepare_dispatcher(available, receiver, pool_waiter, pool_type, initial_count, thread_count_mutex)?;
350 Ok(sender)
351 }
352
353 fn create_thread(i: u32, r: Receiver<EventType<F, T>>,
354 vec_clone: Arc<Mutex<Vec<Sender<EventType<F, T>>>>>,
355 cv_clone: Arc<Condvar>)
356 -> Result<(), ExecutorServiceError> {
357 thread::Builder::new()
358 .name(format!("Thread-{:}", i)).spawn(move || {
359 loop {
360 let fun = r.recv().unwrap();
361 match fun {
362 EventType::Execute(result_sender, sender, fun) => {
363 let t = fun();
364 if let Some(res_sender) = result_sender {
365 res_sender.send(t).unwrap();
367 }
368 if let Ok(mut lock) = vec_clone.lock() {
369 lock.push(sender);
370 cv_clone.notify_all();
371 }
372 }
373 EventType::Quit => {
374 trace!("{:} Received exit", thread::current().name().unwrap());
375 break;
376 }
377 }
378 }
379 })?;
381
382 Ok(())
383 }
384
385 fn prepare_dispatcher(available: Arc<Mutex<Vec<Sender<EventType<F, T>>>>>,
386 receiver: Receiver<DispatcherEventType<F, T>>,
387 pool_waiter: Arc<Condvar>,
388 pool_type: PoolType,
389 current_thread_count: u32,
390 thread_count_mutex: Arc<Mutex<u32>>) -> Result<(), ExecutorServiceError> {
391 thread::Builder::new()
392 .name("Dispatcher".into())
393 .spawn(move || {
394 let mut current_thread_count = current_thread_count;
396 loop {
397 match receiver.recv().unwrap() {
398 DispatcherEventType::Execute(result_sender, func) => {
399 if let Ok(mut lock) = available.lock() {
400 if lock.is_empty() {
401 match pool_type {
403 PoolType::Cached => {
404 if current_thread_count < MAX_THREAD_COUNT {
406 let (s, r) = channel::<EventType<F, T>>();
408 lock.push(s);
409 Self::create_thread(current_thread_count, r, available.clone(), pool_waiter.clone());
411 current_thread_count += 1;
412 let mut count_lock = thread_count_mutex.lock().unwrap();
413 *count_lock = current_thread_count;
414 } else {
415 let _ = pool_waiter.wait(lock);
417 }
418 }
419 PoolType::Fixed => {
420 let _ = pool_waiter.wait(lock);
422 }
423 }
424 }
425 };
426
427 if let Ok(mut lock) = available.lock() {
428 let the_sender = lock.pop().unwrap();
430 the_sender.send(EventType::Execute(result_sender, the_sender.clone(), func)).unwrap();
432 };
433 }
434 DispatcherEventType::Quit => {
435 if let Ok(lock) = available.lock() {
437 for x in &*lock {
438 trace!("AV Send quit");
439 let _ = x.send(EventType::Quit);
440 }
441 }
442
443 break;
444 }
445 }
446 }
447 trace!("Dispatcher exit");
448 })?;
449
450 Ok(())
451 }
452}
453
454
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use super::*;
    use std::thread::sleep;
    use std::thread;
    use std::sync::mpsc::sync_channel;
    use env_logger::{Builder, Env};
    use log::{debug, info};

    /// Installs the logger once, before any test runs; the default filter is
    /// `trace` unless the environment overrides it.
    #[ctor::ctor]
    fn init_env_logger() {
        Builder::from_env(Env::default().default_filter_or("trace")).init();
    }

    /// Fire-and-forget tasks on a fixed pool: every task must run exactly once.
    #[test]
    fn test_execute() -> Result<(), ExecutorServiceError> {
        let max = 100;
        let mut executor_service = Executors::new_fixed_thread_pool(10)?;

        let (sender, receiver) = sync_channel(max);
        for i in 0..max {
            let sender2 = sender.clone();

            executor_service.execute(Box::new(move || {
                sleep(Duration::from_millis(10));
                info!("Hello from {:} {:}", thread::current().name().unwrap(), i);
                sender2.send(1).expect("Send failed");
            }))?;
        }

        // One message per task; blocks until all of them have reported back.
        for _ in 0..max {
            receiver.recv().unwrap();
        }

        Ok(())
    }

    /// A synchronous submission returns the closure's value to the caller.
    #[test]
    fn test_submit_sync() -> Result<(), ExecutorServiceError> {
        let mut executor_service = Executors::new_fixed_thread_pool(2)?;

        let some_param = "Mr White";
        let res = executor_service.submit_sync(Box::new(move || {
            info!("Long computation");
            sleep(Duration::from_secs(5));
            debug!("Hello {:}", some_param);
            info!("Long computation finished");
            2
        }))?;

        trace!("Result: {:#?}", res);
        assert_eq!(res, 2);
        Ok(())
    }

    /// The result of an asynchronous submission can be collected after the
    /// task has already finished.
    #[test]
    fn test_submit_async() -> Result<(), ExecutorServiceError> {
        let mut executor_service = Executors::new_fixed_thread_pool(2)?;

        let some_param = "Mr White";
        let res: Future<String> = executor_service.submit_async(Box::new(move || {
            info!("Long computation");
            sleep(Duration::from_secs(5));
            debug!("Hello {:}", some_param);
            info!("Long computation finished");
            "A string as a result".to_string()
        }))?;

        info!("Main thread wait for 7 seconds");
        sleep(Duration::from_secs(7));
        info!("Main thread resumes after 7 seconds, consuming the future");
        let the_string = res.get()?;
        trace!("Result: {:#?}", &the_string);
        assert_eq!(&the_string, "A string as a result");
        Ok(())
    }

    /// Tasks that outlive the submission loop force a cached pool to start
    /// additional workers.
    #[test]
    fn test_cached_thread_pool_execute() -> Result<(), ExecutorServiceError> {
        let mut executor_service = Executors::new_cached_thread_pool(None)?;

        let (s, r) = channel();
        let some_param = "Mr White";

        for _ in 0..100 {
            let s = s.clone();
            sleep(Duration::from_millis(100));
            debug!("Thread count is {:}", executor_service.get_thread_count().unwrap());
            executor_service.execute(move || {
                info!("Long computation Thread:{:}", thread::current().name().unwrap());
                sleep(Duration::from_millis(15000));
                debug!("Hello {:}", some_param);
                info!("Long computation finished");
                s.send("asdf").expect("Cannot send");
            })?;
        }

        // Each task sleeps longer than the whole submission loop takes, so the
        // initial workers were all busy and the pool had to grow.
        assert!(executor_service.get_thread_count()? > DEFAULT_INITIAL_CACHED_THREAD_COUNT);

        for _ in 0..100 {
            r.recv().expect("Cannot receive");
        }

        Ok(())
    }

    /// A `Future` can be moved to, and consumed on, another thread.
    #[test]
    fn test_submit_async_cached() -> Result<(), ExecutorServiceError> {
        let mut executor_service = Executors::new_cached_thread_pool(Some(5))?;

        info!("{:?}", executor_service.get_thread_count());
        assert_eq!(executor_service.get_thread_count()?, 5);

        let some_param = "Mr White";
        let the_future: Future<String> = executor_service.submit_async(Box::new(move || {
            info!("Long computation");
            sleep(Duration::from_secs(5));
            debug!("Hello {:}", some_param);
            info!("Long computation finished");
            "A string as a result".to_string()
        }))?;

        info!("Main thread wait for 7 seconds");
        sleep(Duration::from_secs(7));
        info!("Main thread resumes after 7 seconds, consuming the future");

        thread::spawn(move || {
            let the_string = the_future.get().expect("No result");
            trace!("Result: {:#?}", &the_string);
            assert_eq!(&the_string, "A string as a result");
        }).join().expect("Join failed");
        Ok(())
    }
}