radiate_core/domain/
executor.rs

1use crate::sync::WaitGroup;
2use crate::sync::WorkResult;
3use crate::sync::get_thread_pool;
4#[cfg(feature = "rayon")]
5use rayon::iter::{IntoParallelIterator, ParallelIterator};
6
/// How work submitted through the methods on `Executor` is scheduled.
///
/// The [`Default`] strategy is `Serial`.
#[derive(Default, Debug, Clone)]
pub enum Executor {
    /// Run each closure directly on the calling thread.
    #[default]
    Serial,
    /// Hand closures to rayon's thread pool (only available with the `rayon` feature).
    #[cfg(feature = "rayon")]
    WorkerPool,
    /// Hand closures to the pool returned by `get_thread_pool` for this worker count.
    FixedSizedWorkerPool(usize),
}
15
16impl Executor {
17    pub fn num_workers(&self) -> usize {
18        match self {
19            Executor::Serial => 1,
20            #[cfg(feature = "rayon")]
21            Executor::WorkerPool => rayon::current_num_threads(),
22            Executor::FixedSizedWorkerPool(num_workers) => *num_workers,
23        }
24    }
25
26    pub fn task<F, R>(&self, f: F) -> WorkResult<R>
27    where
28        F: FnOnce() -> R + Send + 'static,
29        R: Send + 'static,
30    {
31        match self {
32            Executor::Serial => {
33                let result = f();
34                let (tx, rx) = std::sync::mpsc::channel();
35                tx.send(result).unwrap();
36                WorkResult::new(rx)
37            }
38            Executor::FixedSizedWorkerPool(num_workers) => {
39                let pool = get_thread_pool(*num_workers);
40                pool.submit_with_result(f)
41            }
42            #[cfg(feature = "rayon")]
43            Executor::WorkerPool => {
44                let wg = WaitGroup::new();
45                let _wg_clone = wg.guard();
46                let (rx, _tx) = std::sync::mpsc::channel();
47                let work_result = WorkResult::new(_tx);
48                rayon::spawn_fifo(move || {
49                    let res = f();
50                    rx.send(res).unwrap();
51                });
52
53                work_result
54            }
55        }
56    }
57
58    pub fn execute<F, R>(&self, f: F) -> R
59    where
60        F: FnOnce() -> R + Send + 'static,
61        R: Send + 'static,
62    {
63        match self {
64            Executor::Serial => f(),
65            Executor::FixedSizedWorkerPool(num_workers) => {
66                get_thread_pool(*num_workers).submit_with_result(f).result()
67            }
68            #[cfg(feature = "rayon")]
69            Executor::WorkerPool => {
70                use std::sync::{Arc, Mutex};
71
72                let result = Arc::new(Mutex::new(None));
73                let result_clone = Arc::clone(&result);
74                let wg = WaitGroup::new();
75                let _wg_clone = wg.guard();
76                rayon::spawn_fifo(move || {
77                    let res = f();
78                    let mut guard = result_clone.lock().unwrap();
79                    *guard = Some(res);
80                    drop(_wg_clone);
81                });
82
83                wg.wait();
84
85                (*result.lock().unwrap()).take().unwrap()
86            }
87        }
88    }
89
90    pub fn execute_batch<F, R>(&self, f: Vec<F>) -> Vec<R>
91    where
92        F: FnOnce() -> R + Send + 'static,
93        R: Send + 'static,
94    {
95        match self {
96            Executor::Serial => f.into_iter().map(|func| func()).collect(),
97            Executor::FixedSizedWorkerPool(num_workers) => {
98                let pool = get_thread_pool(*num_workers);
99                let mut results = Vec::with_capacity(f.len());
100
101                for job in f {
102                    results.push(pool.submit_with_result(|| job()));
103                }
104
105                results.into_iter().map(|r| r.result()).collect()
106            }
107            #[cfg(feature = "rayon")]
108            Executor::WorkerPool => f.into_par_iter().map(|func| func()).collect(),
109        }
110    }
111
112    pub fn submit<F>(&self, f: F)
113    where
114        F: FnOnce() + Send + 'static,
115    {
116        match self {
117            Executor::Serial => f(),
118            Executor::FixedSizedWorkerPool(num_workers) => {
119                let pool = get_thread_pool(*num_workers);
120                pool.submit(f)
121            }
122            #[cfg(feature = "rayon")]
123            Executor::WorkerPool => {
124                rayon::spawn_fifo(move || {
125                    f();
126                });
127            }
128        }
129    }
130
131    pub fn submit_blocking<F>(&self, f: Vec<F>)
132    where
133        F: FnOnce() + Send + 'static,
134    {
135        match self {
136            Executor::Serial => {
137                for func in f {
138                    func();
139                }
140            }
141            Executor::FixedSizedWorkerPool(num_workers) => {
142                let pool = get_thread_pool(*num_workers);
143                let wg = WaitGroup::new();
144                for job in f {
145                    let guard = wg.guard();
146                    pool.submit(move || {
147                        job();
148                        drop(guard);
149                    });
150                }
151
152                wg.wait();
153            }
154            #[cfg(feature = "rayon")]
155            Executor::WorkerPool => {
156                let wg = WaitGroup::new();
157                let with_guards = f
158                    .into_iter()
159                    .map(|job| {
160                        let guard = wg.guard();
161                        move || {
162                            job();
163                            drop(guard);
164                        }
165                    })
166                    .collect::<Vec<_>>();
167
168                with_guards.into_par_iter().for_each(|func| {
169                    func();
170                });
171
172                wg.wait();
173            }
174        }
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::Executor;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::mpsc;
    use std::sync::Arc;
    use std::time::Duration;

    /// Builds `count` jobs where job `i` (1-based) returns `i * 2`, so a reordered batch
    /// result is easy to spot.
    fn doubling_jobs(count: i32) -> Vec<impl FnOnce() -> i32 + Send + 'static> {
        (1..=count).map(|i| move || i * 2).collect()
    }

    /// Drives every public entry point of `executor` and checks the results.
    fn exercise(executor: &Executor) {
        assert_eq!(executor.execute(|| 42), 42);
        assert_eq!(executor.task(|| 21 * 2).result(), 42);
        assert_eq!(executor.execute_batch(doubling_jobs(3)), vec![2, 4, 6]);

        // `submit_blocking` must not return until every job has run.
        let counter = Arc::new(AtomicUsize::new(0));
        let jobs = (0..8)
            .map(|_| {
                let counter = Arc::clone(&counter);
                move || {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })
            .collect::<Vec<_>>();
        executor.submit_blocking(jobs);
        assert_eq!(counter.load(Ordering::SeqCst), 8);

        // `submit` is fire-and-forget; wait on a channel, bounded so a hang fails the test.
        let (tx, rx) = mpsc::channel();
        executor.submit(move || tx.send(7).unwrap());
        assert_eq!(rx.recv_timeout(Duration::from_secs(5)), Ok(7));
    }

    #[test]
    fn test_executor_default_is_serial() {
        assert!(matches!(Executor::default(), Executor::Serial));
    }

    #[test]
    fn test_executor_serial() {
        let executor = Executor::Serial;
        assert_eq!(executor.num_workers(), 1);
        exercise(&executor);
    }

    #[test]
    fn test_executor_fixed_sized_worker_pool() {
        let executor = Executor::FixedSizedWorkerPool(4);
        assert_eq!(executor.num_workers(), 4);
        exercise(&executor);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn test_executor_worker_pool() {
        let executor = Executor::WorkerPool;
        assert_eq!(executor.num_workers(), rayon::current_num_threads());
        exercise(&executor);
    }
}