Skip to main content

radiate_core/domain/
executor.rs

1use crate::sync::WaitGroup;
2use crate::sync::get_thread_pool;
3#[cfg(feature = "rayon")]
4use rayon::iter::{IntoParallelIterator, ParallelIterator};
5
/// Selects how jobs are run: inline on the calling thread, or on a pool of workers.
#[derive(Debug, Default, Clone)]
pub enum Executor {
    /// Run each job inline on the calling thread. This is the `Default` variant.
    #[default]
    Serial,
    /// Run jobs on rayon's thread pool (only available with the `rayon` feature).
    #[cfg(feature = "rayon")]
    WorkerPool,
    /// Run jobs on a worker pool sized by the contained worker count.
    FixedSizedWorkerPool(usize),
}
14
15impl Executor {
16    pub fn is_parallel(&self) -> bool {
17        match self {
18            Executor::Serial => false,
19            #[cfg(feature = "rayon")]
20            Executor::WorkerPool => true,
21            Executor::FixedSizedWorkerPool(_) => true,
22        }
23    }
24
25    pub fn num_workers(&self) -> usize {
26        match self {
27            Executor::Serial => 1,
28            #[cfg(feature = "rayon")]
29            Executor::WorkerPool => rayon::current_num_threads(),
30            Executor::FixedSizedWorkerPool(num_workers) => *num_workers,
31        }
32    }
33
34    pub fn execute<F, R>(&self, f: F) -> R
35    where
36        F: FnOnce() -> R + Send + 'static,
37        R: Send + 'static,
38    {
39        match self {
40            Executor::Serial => f(),
41            Executor::FixedSizedWorkerPool(num_workers) => {
42                get_thread_pool(*num_workers).submit_with_result(f).result()
43            }
44            #[cfg(feature = "rayon")]
45            Executor::WorkerPool => {
46                use std::sync::{Arc, Mutex};
47
48                let result = Arc::new(Mutex::new(None));
49                let result_clone = Arc::clone(&result);
50                let wg = WaitGroup::new();
51                let _wg_clone = wg.guard();
52                rayon::spawn_fifo(move || {
53                    let res = f();
54                    let mut guard = result_clone.lock().unwrap();
55                    *guard = Some(res);
56                    drop(_wg_clone);
57                });
58
59                wg.wait();
60
61                (*result.lock().unwrap()).take().unwrap()
62            }
63        }
64    }
65
66    pub fn execute_batch<I, F, R>(&self, jobs: I) -> Vec<R>
67    where
68        I: IntoIterator<Item = F>,
69        F: FnOnce() -> R + Send + 'static,
70        R: Send + 'static,
71    {
72        match self {
73            Executor::Serial => jobs.into_iter().map(|func| func()).collect(),
74            Executor::FixedSizedWorkerPool(num_workers) => {
75                let pool = get_thread_pool(*num_workers);
76                let iter = jobs.into_iter();
77                let mut results = Vec::with_capacity(iter.size_hint().0);
78
79                for job in iter {
80                    results.push(pool.submit_with_result(job));
81                }
82
83                results.into_iter().map(|r| r.result()).collect()
84            }
85            #[cfg(feature = "rayon")]
86            Executor::WorkerPool => jobs
87                .into_iter()
88                .collect::<Vec<_>>()
89                .into_par_iter()
90                .map(|func| func())
91                .collect(),
92        }
93    }
94
95    pub fn submit<F>(&self, f: F)
96    where
97        F: FnOnce() + Send + 'static,
98    {
99        match self {
100            Executor::Serial => f(),
101            Executor::FixedSizedWorkerPool(num_workers) => {
102                let pool = get_thread_pool(*num_workers);
103                pool.submit(f)
104            }
105            #[cfg(feature = "rayon")]
106            Executor::WorkerPool => {
107                rayon::spawn_fifo(move || {
108                    f();
109                });
110            }
111        }
112    }
113
114    pub fn submit_blocking<I, F>(&self, jobs: I)
115    where
116        I: IntoIterator<Item = F>,
117        F: FnOnce() + Send + 'static,
118    {
119        match self {
120            Executor::Serial => {
121                for func in jobs {
122                    func();
123                }
124            }
125            Executor::FixedSizedWorkerPool(num_workers) => {
126                let pool = get_thread_pool(*num_workers);
127                let wg = WaitGroup::new();
128                for job in jobs {
129                    let guard = wg.guard();
130                    pool.submit(move || {
131                        job();
132                        drop(guard);
133                    });
134                }
135
136                wg.wait();
137            }
138            #[cfg(feature = "rayon")]
139            Executor::WorkerPool => {
140                let wg = WaitGroup::new();
141                let with_guards = jobs
142                    .into_iter()
143                    .map(|job| {
144                        let guard = wg.guard();
145                        move || {
146                            job();
147                            drop(guard);
148                        }
149                    })
150                    .collect::<Vec<_>>();
151
152                with_guards.into_par_iter().for_each(|func| {
153                    func();
154                });
155
156                wg.wait();
157            }
158        }
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::Executor;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Three boxed jobs that return 1*2, 2*2 and 3*2, built from a range so no literal
    /// `1 * x` expression trips clippy's `identity_op` lint.
    fn doubling_jobs() -> Vec<Box<dyn FnOnce() -> i32 + Send>> {
        (1..=3)
            .map(|i| Box::new(move || i * 2) as Box<dyn FnOnce() -> i32 + Send>)
            .collect()
    }

    /// `n` jobs that each increment `counter` once.
    fn counting_jobs(counter: &Arc<AtomicUsize>, n: usize) -> Vec<impl FnOnce() + Send + 'static> {
        (0..n)
            .map(|_| {
                let counter = Arc::clone(counter);
                move || {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })
            .collect()
    }

    #[test]
    fn test_executor_serial() {
        let executor = Executor::Serial;
        assert!(!executor.is_parallel());
        assert_eq!(executor.num_workers(), 1);
        assert_eq!(executor.execute(|| 42), 42);
        assert_eq!(executor.execute_batch(doubling_jobs()), vec![2, 4, 6]);
    }

    #[test]
    fn test_executor_serial_submit_runs_inline() {
        let executor = Executor::Serial;
        let counter = Arc::new(AtomicUsize::new(0));

        let c = Arc::clone(&counter);
        executor.submit(move || {
            c.fetch_add(1, Ordering::SeqCst);
        });
        // Serial submission runs the job before returning.
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        executor.submit_blocking(counting_jobs(&counter, 3));
        assert_eq!(counter.load(Ordering::SeqCst), 4);
    }

    #[test]
    fn test_executor_fixed_sized_worker_pool() {
        let executor = Executor::FixedSizedWorkerPool(4);
        assert!(executor.is_parallel());
        assert_eq!(executor.num_workers(), 4);
        assert_eq!(executor.execute(|| 42), 42);
        assert_eq!(executor.execute_batch(doubling_jobs()), vec![2, 4, 6]);
    }

    #[test]
    fn test_executor_fixed_sized_worker_pool_submit_blocking_waits() {
        let executor = Executor::FixedSizedWorkerPool(4);
        let counter = Arc::new(AtomicUsize::new(0));
        executor.submit_blocking(counting_jobs(&counter, 16));
        // submit_blocking must not return until every job has run.
        assert_eq!(counter.load(Ordering::SeqCst), 16);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn test_executor_worker_pool() {
        let executor = Executor::WorkerPool;
        assert!(executor.is_parallel());
        assert_eq!(executor.execute(|| 42), 42);
        assert_eq!(executor.execute_batch(doubling_jobs()), vec![2, 4, 6]);

        let counter = Arc::new(AtomicUsize::new(0));
        executor.submit_blocking(counting_jobs(&counter, 16));
        assert_eq!(counter.load(Ordering::SeqCst), 16);
    }
}