Skip to main content

radiate_core/domain/
executor.rs

1use crate::sync::WaitGroup;
2use crate::sync::get_thread_pool;
3#[cfg(feature = "rayon")]
4use rayon::iter::{IntoParallelIterator, ParallelIterator};
5
/// Strategy used to run units of work.
///
/// * [`Executor::Serial`] (the default) runs work inline on the calling thread.
/// * `Executor::WorkerPool` (only with the `rayon` feature) hands work to rayon's thread pool.
/// * [`Executor::FixedSizedWorkerPool`] hands work to the pool returned by
///   `crate::sync::get_thread_pool` for the configured worker count.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum Executor {
    /// Run every job inline on the calling thread.
    #[default]
    Serial,
    /// Run jobs on rayon's thread pool (`rayon::spawn_fifo` / parallel iterators).
    #[cfg(feature = "rayon")]
    WorkerPool,
    /// Run jobs on the pool obtained from `get_thread_pool(n)`; `n` is the requested number of
    /// workers (presumably threads — see `crate::sync::get_thread_pool`).
    FixedSizedWorkerPool(usize),
}
14
15impl Executor {
16    pub fn is_parallel(&self) -> bool {
17        match self {
18            Executor::Serial => false,
19            #[cfg(feature = "rayon")]
20            Executor::WorkerPool => true,
21            Executor::FixedSizedWorkerPool(_) => true,
22        }
23    }
24
25    pub fn num_workers(&self) -> usize {
26        match self {
27            Executor::Serial => 1,
28            #[cfg(feature = "rayon")]
29            Executor::WorkerPool => rayon::current_num_threads(),
30            Executor::FixedSizedWorkerPool(num_workers) => *num_workers,
31        }
32    }
33
34    pub fn execute<F, R>(&self, f: F) -> R
35    where
36        F: FnOnce() -> R + Send + 'static,
37        R: Send + 'static,
38    {
39        match self {
40            Executor::Serial => f(),
41            Executor::FixedSizedWorkerPool(num_workers) => {
42                get_thread_pool(*num_workers).submit_with_result(f).result()
43            }
44            #[cfg(feature = "rayon")]
45            Executor::WorkerPool => {
46                use std::sync::{Arc, Mutex};
47
48                let result = Arc::new(Mutex::new(None));
49                let result_clone = Arc::clone(&result);
50                let wg = WaitGroup::new();
51                let _wg_clone = wg.guard();
52                rayon::spawn_fifo(move || {
53                    let res = f();
54                    let mut guard = result_clone.lock().unwrap();
55                    *guard = Some(res);
56                    drop(_wg_clone);
57                });
58
59                wg.wait();
60
61                (*result.lock().unwrap()).take().unwrap()
62            }
63        }
64    }
65
66    pub fn execute_batch<F, R>(&self, f: Vec<F>) -> Vec<R>
67    where
68        F: FnOnce() -> R + Send + 'static,
69        R: Send + 'static,
70    {
71        match self {
72            Executor::Serial => f.into_iter().map(|func| func()).collect(),
73            Executor::FixedSizedWorkerPool(num_workers) => {
74                let pool = get_thread_pool(*num_workers);
75                let mut results = Vec::with_capacity(f.len());
76
77                for job in f {
78                    results.push(pool.submit_with_result(|| job()));
79                }
80
81                results.into_iter().map(|r| r.result()).collect()
82            }
83            #[cfg(feature = "rayon")]
84            Executor::WorkerPool => f.into_par_iter().map(|func| func()).collect(),
85        }
86    }
87
88    pub fn submit<F>(&self, f: F)
89    where
90        F: FnOnce() + Send + 'static,
91    {
92        match self {
93            Executor::Serial => f(),
94            Executor::FixedSizedWorkerPool(num_workers) => {
95                let pool = get_thread_pool(*num_workers);
96                pool.submit(f)
97            }
98            #[cfg(feature = "rayon")]
99            Executor::WorkerPool => {
100                rayon::spawn_fifo(move || {
101                    f();
102                });
103            }
104        }
105    }
106
107    pub fn submit_blocking<F>(&self, f: Vec<F>)
108    where
109        F: FnOnce() + Send + 'static,
110    {
111        match self {
112            Executor::Serial => {
113                for func in f {
114                    func();
115                }
116            }
117            Executor::FixedSizedWorkerPool(num_workers) => {
118                let pool = get_thread_pool(*num_workers);
119                let wg = WaitGroup::new();
120                for job in f {
121                    let guard = wg.guard();
122                    pool.submit(move || {
123                        job();
124                        drop(guard);
125                    });
126                }
127
128                wg.wait();
129            }
130            #[cfg(feature = "rayon")]
131            Executor::WorkerPool => {
132                let wg = WaitGroup::new();
133                let with_guards = f
134                    .into_iter()
135                    .map(|job| {
136                        let guard = wg.guard();
137                        move || {
138                            job();
139                            drop(guard);
140                        }
141                    })
142                    .collect::<Vec<_>>();
143
144                with_guards.into_par_iter().for_each(|func| {
145                    func();
146                });
147
148                wg.wait();
149            }
150        }
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::Executor;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::mpsc;
    use std::sync::Arc;
    use std::time::Duration;

    /// Three jobs returning `2`, `4` and `6`. They all share one closure type, so they fit in a
    /// single `Vec` without relying on fn-pointer coercion.
    fn doubling_jobs() -> Vec<impl FnOnce() -> i32 + Send + 'static> {
        (1..=3).map(|n| move || n * 2).collect()
    }

    #[test]
    fn test_executor_default_is_serial() {
        let executor = Executor::default();
        assert!(matches!(executor, Executor::Serial));
        assert!(!executor.is_parallel());
        assert_eq!(executor.num_workers(), 1);
    }

    #[test]
    fn test_executor_serial() {
        let executor = Executor::Serial;
        assert_eq!(executor.execute(|| 42), 42);
        assert_eq!(executor.execute_batch(doubling_jobs()), vec![2, 4, 6]);
    }

    #[test]
    fn test_executor_fixed_sized_worker_pool() {
        let executor = Executor::FixedSizedWorkerPool(4);
        assert!(executor.is_parallel());
        assert_eq!(executor.num_workers(), 4);
        assert_eq!(executor.execute(|| 42), 42);
        assert_eq!(executor.execute_batch(doubling_jobs()), vec![2, 4, 6]);
    }

    #[test]
    fn test_execute_batch_empty() {
        for executor in [Executor::Serial, Executor::FixedSizedWorkerPool(4)] {
            let results = executor.execute_batch(Vec::<fn() -> i32>::new());
            assert!(results.is_empty());
        }
    }

    #[test]
    fn test_submit_runs_job() {
        for executor in [Executor::Serial, Executor::FixedSizedWorkerPool(4)] {
            let (tx, rx) = mpsc::channel();
            executor.submit(move || {
                // Ignore a send error so a timed-out test cannot panic a pool worker.
                let _ = tx.send(7);
            });
            let received = rx
                .recv_timeout(Duration::from_secs(5))
                .expect("submitted job did not run");
            assert_eq!(received, 7);
        }
    }

    #[test]
    fn test_submit_blocking_runs_every_job() {
        for executor in [Executor::Serial, Executor::FixedSizedWorkerPool(4)] {
            let counter = Arc::new(AtomicUsize::new(0));
            let jobs: Vec<_> = (0..8)
                .map(|_| {
                    let counter = Arc::clone(&counter);
                    move || {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })
                .collect();

            executor.submit_blocking(jobs);

            // `submit_blocking` must not return before every job has run.
            assert_eq!(counter.load(Ordering::SeqCst), 8);
        }
    }
}