radiate_core/domain/
executor.rs1use crate::sync::WaitGroup;
2use crate::sync::get_thread_pool;
3#[cfg(feature = "rayon")]
4use rayon::iter::{IntoParallelIterator, ParallelIterator};
5
/// Strategy used to run closures: either inline on the calling thread or on a thread pool.
///
/// [`Executor::Serial`] is the [`Default`].
#[derive(Debug, Clone, Default)]
pub enum Executor {
    /// Run every job directly on the calling thread.
    #[default]
    Serial,

    /// Run jobs on rayon's thread pool (only available with the `rayon` feature).
    #[cfg(feature = "rayon")]
    WorkerPool,

    /// Run jobs on the pool returned by `crate::sync::get_thread_pool`, sized by the payload.
    FixedSizedWorkerPool(usize),
}
14
15impl Executor {
16 pub fn is_parallel(&self) -> bool {
17 match self {
18 Executor::Serial => false,
19 #[cfg(feature = "rayon")]
20 Executor::WorkerPool => true,
21 Executor::FixedSizedWorkerPool(_) => true,
22 }
23 }
24
25 pub fn num_workers(&self) -> usize {
26 match self {
27 Executor::Serial => 1,
28 #[cfg(feature = "rayon")]
29 Executor::WorkerPool => rayon::current_num_threads(),
30 Executor::FixedSizedWorkerPool(num_workers) => *num_workers,
31 }
32 }
33
34 pub fn execute<F, R>(&self, f: F) -> R
35 where
36 F: FnOnce() -> R + Send + 'static,
37 R: Send + 'static,
38 {
39 match self {
40 Executor::Serial => f(),
41 Executor::FixedSizedWorkerPool(num_workers) => {
42 get_thread_pool(*num_workers).submit_with_result(f).result()
43 }
44 #[cfg(feature = "rayon")]
45 Executor::WorkerPool => {
46 use std::sync::{Arc, Mutex};
47
48 let result = Arc::new(Mutex::new(None));
49 let result_clone = Arc::clone(&result);
50 let wg = WaitGroup::new();
51 let _wg_clone = wg.guard();
52 rayon::spawn_fifo(move || {
53 let res = f();
54 let mut guard = result_clone.lock().unwrap();
55 *guard = Some(res);
56 drop(_wg_clone);
57 });
58
59 wg.wait();
60
61 (*result.lock().unwrap()).take().unwrap()
62 }
63 }
64 }
65
66 pub fn execute_batch<F, R>(&self, f: Vec<F>) -> Vec<R>
67 where
68 F: FnOnce() -> R + Send + 'static,
69 R: Send + 'static,
70 {
71 match self {
72 Executor::Serial => f.into_iter().map(|func| func()).collect(),
73 Executor::FixedSizedWorkerPool(num_workers) => {
74 let pool = get_thread_pool(*num_workers);
75 let mut results = Vec::with_capacity(f.len());
76
77 for job in f {
78 results.push(pool.submit_with_result(|| job()));
79 }
80
81 results.into_iter().map(|r| r.result()).collect()
82 }
83 #[cfg(feature = "rayon")]
84 Executor::WorkerPool => f.into_par_iter().map(|func| func()).collect(),
85 }
86 }
87
88 pub fn submit<F>(&self, f: F)
89 where
90 F: FnOnce() + Send + 'static,
91 {
92 match self {
93 Executor::Serial => f(),
94 Executor::FixedSizedWorkerPool(num_workers) => {
95 let pool = get_thread_pool(*num_workers);
96 pool.submit(f)
97 }
98 #[cfg(feature = "rayon")]
99 Executor::WorkerPool => {
100 rayon::spawn_fifo(move || {
101 f();
102 });
103 }
104 }
105 }
106
107 pub fn submit_blocking<F>(&self, f: Vec<F>)
108 where
109 F: FnOnce() + Send + 'static,
110 {
111 match self {
112 Executor::Serial => {
113 for func in f {
114 func();
115 }
116 }
117 Executor::FixedSizedWorkerPool(num_workers) => {
118 let pool = get_thread_pool(*num_workers);
119 let wg = WaitGroup::new();
120 for job in f {
121 let guard = wg.guard();
122 pool.submit(move || {
123 job();
124 drop(guard);
125 });
126 }
127
128 wg.wait();
129 }
130 #[cfg(feature = "rayon")]
131 Executor::WorkerPool => {
132 let wg = WaitGroup::new();
133 let with_guards = f
134 .into_iter()
135 .map(|job| {
136 let guard = wg.guard();
137 move || {
138 job();
139 drop(guard);
140 }
141 })
142 .collect::<Vec<_>>();
143
144 with_guards.into_par_iter().for_each(|func| {
145 func();
146 });
147
148 wg.wait();
149 }
150 }
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::Executor;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds `count` jobs that each increment `counter` exactly once.
    fn counting_jobs(
        counter: &Arc<AtomicUsize>,
        count: usize,
    ) -> Vec<impl FnOnce() + Send + 'static> {
        (0..count)
            .map(|_| {
                let counter = Arc::clone(counter);
                move || {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })
            .collect()
    }

    #[test]
    fn test_executor_serial() {
        let executor = Executor::Serial;
        assert!(!executor.is_parallel());
        assert_eq!(executor.num_workers(), 1);

        assert_eq!(executor.execute(|| 42), 42);

        let batch = vec![|| 1 + 1, || 2 + 2, || 3 + 3];
        assert_eq!(executor.execute_batch(batch), vec![2, 4, 6]);
    }

    #[test]
    fn test_executor_fixed_sized_worker_pool() {
        let executor = Executor::FixedSizedWorkerPool(4);
        assert!(executor.is_parallel());
        assert_eq!(executor.num_workers(), 4);

        assert_eq!(executor.execute(|| 42), 42);

        let batch = vec![|| 1 + 1, || 2 + 2, || 3 + 3];
        assert_eq!(executor.execute_batch(batch), vec![2, 4, 6]);
    }

    #[test]
    fn test_execute_batch_empty() {
        for executor in [Executor::Serial, Executor::FixedSizedWorkerPool(4)] {
            let results: Vec<i32> = executor.execute_batch(Vec::<fn() -> i32>::new());
            assert!(results.is_empty());
        }
    }

    #[test]
    fn test_submit_blocking_runs_every_job() {
        for executor in [Executor::Serial, Executor::FixedSizedWorkerPool(4)] {
            let counter = Arc::new(AtomicUsize::new(0));
            executor.submit_blocking(counting_jobs(&counter, 16));
            assert_eq!(counter.load(Ordering::SeqCst), 16);
        }
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn test_executor_worker_pool() {
        let executor = Executor::WorkerPool;
        assert!(executor.is_parallel());
        assert!(executor.num_workers() >= 1);

        assert_eq!(executor.execute(|| 42), 42);

        let batch = vec![|| 1 + 1, || 2 + 2, || 3 + 3];
        assert_eq!(executor.execute_batch(batch), vec![2, 4, 6]);

        let counter = Arc::new(AtomicUsize::new(0));
        executor.submit_blocking(counting_jobs(&counter, 16));
        assert_eq!(counter.load(Ordering::SeqCst), 16);
    }
}