radiate_core/domain/
executor.rs1use crate::sync::WaitGroup;
2use crate::sync::WorkResult;
3use crate::sync::get_thread_pool;
4#[cfg(feature = "rayon")]
5use rayon::iter::{IntoParallelIterator, ParallelIterator};
6
/// Strategy used to run jobs: inline on the calling thread, or on a pool of workers.
#[derive(Debug, Clone, Default)]
pub enum Executor {
    /// Run every job on the calling thread. This is the default variant.
    #[default]
    Serial,
    /// Run jobs on rayon's global thread pool (only available with the `rayon` feature).
    #[cfg(feature = "rayon")]
    WorkerPool,
    /// Run jobs on the shared pool obtained for the given worker count.
    FixedSizedWorkerPool(usize),
}

16impl Executor {
17 pub fn num_workers(&self) -> usize {
18 match self {
19 Executor::Serial => 1,
20 #[cfg(feature = "rayon")]
21 Executor::WorkerPool => rayon::current_num_threads(),
22 Executor::FixedSizedWorkerPool(num_workers) => *num_workers,
23 }
24 }
25
26 pub fn task<F, R>(&self, f: F) -> WorkResult<R>
27 where
28 F: FnOnce() -> R + Send + 'static,
29 R: Send + 'static,
30 {
31 match self {
32 Executor::Serial => {
33 let result = f();
34 let (tx, rx) = std::sync::mpsc::channel();
35 tx.send(result).unwrap();
36 WorkResult::new(rx)
37 }
38 Executor::FixedSizedWorkerPool(num_workers) => {
39 let pool = get_thread_pool(*num_workers);
40 pool.submit_with_result(f)
41 }
42 #[cfg(feature = "rayon")]
43 Executor::WorkerPool => {
44 let wg = WaitGroup::new();
45 let _wg_clone = wg.guard();
46 let (rx, _tx) = std::sync::mpsc::channel();
47 let work_result = WorkResult::new(_tx);
48 rayon::spawn_fifo(move || {
49 let res = f();
50 rx.send(res).unwrap();
51 });
52
53 work_result
54 }
55 }
56 }
57
58 pub fn execute<F, R>(&self, f: F) -> R
59 where
60 F: FnOnce() -> R + Send + 'static,
61 R: Send + 'static,
62 {
63 match self {
64 Executor::Serial => f(),
65 Executor::FixedSizedWorkerPool(num_workers) => {
66 get_thread_pool(*num_workers).submit_with_result(f).result()
67 }
68 #[cfg(feature = "rayon")]
69 Executor::WorkerPool => {
70 use std::sync::{Arc, Mutex};
71
72 let result = Arc::new(Mutex::new(None));
73 let result_clone = Arc::clone(&result);
74 let wg = WaitGroup::new();
75 let _wg_clone = wg.guard();
76 rayon::spawn_fifo(move || {
77 let res = f();
78 let mut guard = result_clone.lock().unwrap();
79 *guard = Some(res);
80 drop(_wg_clone);
81 });
82
83 wg.wait();
84
85 (*result.lock().unwrap()).take().unwrap()
86 }
87 }
88 }
89
90 pub fn execute_batch<F, R>(&self, f: Vec<F>) -> Vec<R>
91 where
92 F: FnOnce() -> R + Send + 'static,
93 R: Send + 'static,
94 {
95 match self {
96 Executor::Serial => f.into_iter().map(|func| func()).collect(),
97 Executor::FixedSizedWorkerPool(num_workers) => {
98 let pool = get_thread_pool(*num_workers);
99 let mut results = Vec::with_capacity(f.len());
100
101 for job in f {
102 results.push(pool.submit_with_result(|| job()));
103 }
104
105 results.into_iter().map(|r| r.result()).collect()
106 }
107 #[cfg(feature = "rayon")]
108 Executor::WorkerPool => f.into_par_iter().map(|func| func()).collect(),
109 }
110 }
111
112 pub fn submit<F>(&self, f: F)
113 where
114 F: FnOnce() + Send + 'static,
115 {
116 match self {
117 Executor::Serial => f(),
118 Executor::FixedSizedWorkerPool(num_workers) => {
119 let pool = get_thread_pool(*num_workers);
120 pool.submit(f)
121 }
122 #[cfg(feature = "rayon")]
123 Executor::WorkerPool => {
124 rayon::spawn_fifo(move || {
125 f();
126 });
127 }
128 }
129 }
130
131 pub fn submit_blocking<F>(&self, f: Vec<F>)
132 where
133 F: FnOnce() + Send + 'static,
134 {
135 match self {
136 Executor::Serial => {
137 for func in f {
138 func();
139 }
140 }
141 Executor::FixedSizedWorkerPool(num_workers) => {
142 let pool = get_thread_pool(*num_workers);
143 let wg = WaitGroup::new();
144 for job in f {
145 let guard = wg.guard();
146 pool.submit(move || {
147 job();
148 drop(guard);
149 });
150 }
151
152 wg.wait();
153 }
154 #[cfg(feature = "rayon")]
155 Executor::WorkerPool => {
156 let wg = WaitGroup::new();
157 let with_guards = f
158 .into_iter()
159 .map(|job| {
160 let guard = wg.guard();
161 move || {
162 job();
163 drop(guard);
164 }
165 })
166 .collect::<Vec<_>>();
167
168 with_guards.into_par_iter().for_each(|func| {
169 func();
170 });
171
172 wg.wait();
173 }
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::Executor;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Checks `execute`, `task`, `execute_batch` and `submit_blocking` on one executor.
    fn exercise(executor: &Executor) {
        assert_eq!(executor.execute(|| 42), 42);
        assert_eq!(executor.task(|| 7).result(), 7);

        // One closure expression, so every job has the same type: yields [2, 4, 6].
        let batch: Vec<_> = (1..=3).map(|i| move || i * 2).collect();
        assert_eq!(executor.execute_batch(batch), vec![2, 4, 6]);

        let counter = Arc::new(AtomicUsize::new(0));
        let jobs: Vec<_> = (0..8)
            .map(|_| {
                let counter = Arc::clone(&counter);
                move || {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })
            .collect();
        executor.submit_blocking(jobs);
        // `submit_blocking` must not return before every job has run.
        assert_eq!(counter.load(Ordering::SeqCst), 8);
    }

    #[test]
    fn test_executor_serial() {
        let executor = Executor::Serial;
        assert_eq!(executor.num_workers(), 1);
        exercise(&executor);
    }

    #[test]
    fn test_executor_fixed_sized_worker_pool() {
        let executor = Executor::FixedSizedWorkerPool(4);
        assert_eq!(executor.num_workers(), 4);
        exercise(&executor);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn test_executor_worker_pool() {
        let executor = Executor::WorkerPool;
        assert_eq!(executor.num_workers(), rayon::current_num_threads());
        exercise(&executor);
    }
}