1use std::collections::VecDeque;
10use std::future::Future;
11use std::sync::Arc;
12use std::thread::JoinHandle;
13
14use log::{debug, error};
15
16use crate::single::{SingleThreadTask, TaskExchange};
17use crate::{
18 CompletableTask, CurrentThreadExecutor, Exchanger, ExchangerError, TaskError, TaskHandle,
19};
20
/// Owner of one worker thread spawned by [`Worker::new`].
pub(crate) struct Worker {
    /// Join handle of the worker thread. Wrapped in an `Option` so the `Drop`
    /// impl can `take()` it and join the thread.
    handle: Option<JoinHandle<()>>,
}
24
25impl Worker {
26 pub fn new(
27 exchanger: Exchanger<WorkerCommand>,
28 name: String,
29 ) -> Result<Worker, std::io::Error> {
30 let handle = std::thread::Builder::new()
31 .name(name.clone())
32 .spawn(move || {
33 let mut current = CurrentThreadExecutor::new();
34 loop {
35 let task = match exchanger.take() {
36 Ok(e) => match e {
37 WorkerCommand::Run(t) => t,
38 WorkerCommand::Close => {
39 debug!("Close command received, closing worker {:?}", name);
40 break;
41 }
42 },
43 Err(e) => {
44 if let ExchangerError::TaskError(e) = e {
45 if e != TaskError::ExecutorStoppingError {
46 error!("Error receiving new task: {e:?}");
47 }
48 } else {
49 error!("Error receiving new task: {e:?}");
50 }
51 break;
52 }
53 };
54 current.submit(task.inner);
55 current.run_until_complete();
56 }
57 })?;
58 Ok(Worker {
59 handle: Some(handle),
60 })
61 }
62}
63
64impl Drop for Worker {
65 fn drop(&mut self) {
66 if let Some(handle) = self.handle.take() {
67 let _res = handle.join();
68 }
69 }
70}
71
/// Configuration used to construct a [`MultiThreadedExecutor`].
#[derive(Debug)]
pub struct Builder {
    /// Name handed to the executor by `build`; the executor uses it as the
    /// prefix of its worker thread names.
    name: String,
    /// Worker limit handed to the executor by `build`.
    max_workers: usize,
}
77
78impl Default for Builder {
79 fn default() -> Self {
80 Builder::new()
81 }
82}
83
84impl Builder {
85 #[must_use]
86 pub fn new() -> Builder {
87 Builder {
88 name: String::new(),
89 max_workers: 1,
90 }
91 }
92
93 #[must_use]
94 pub fn with_name(self, name: &str) -> Self {
95 Builder {
96 name: name.to_string(),
97 ..self
98 }
99 }
100
101 #[must_use]
102 pub fn with_max_workers(self, max_workers: usize) -> Self {
103 Builder {
104 max_workers,
105 ..self
106 }
107 }
108
109 #[must_use]
113 #[cfg(feature = "num_cpus")]
114 pub fn with_num_cpu_workers(self) -> Self {
115 self.with_max_workers(num_cpus::get())
116 }
117
118 #[must_use]
119 pub fn build(self) -> MultiThreadedExecutor {
120 MultiThreadedExecutor {
121 exchanger: Exchanger::new(1),
122 workers: Default::default(),
123 max_workers: self.max_workers,
124 worker_ctr: 0,
125 name: self.name,
126 }
127 }
128}
129
/// Executor that hands submitted futures to a set of worker threads.
pub struct MultiThreadedExecutor {
    /// Hand-off point for `WorkerCommand`s; every worker is given a clone of
    /// it in `add_worker`.
    exchanger: Exchanger<WorkerCommand>,
    /// Workers spawned by `add_worker` that `remove_worker` has not yet
    /// dropped.
    workers: VecDeque<Worker>,
    /// `submit` only tries to spawn an extra worker while `worker_ctr` is
    /// below this value.
    max_workers: usize,
    /// Incremented by `add_worker` and used as the index in worker thread
    /// names. Nothing in this file decrements it.
    worker_ctr: usize,
    /// Prefix of the worker thread names.
    name: String,
}
137
138impl Default for MultiThreadedExecutor {
139 fn default() -> Self {
140 Self::new_single()
141 }
142}
143
144impl MultiThreadedExecutor {
145 pub fn new_single() -> MultiThreadedExecutor {
148 MultiThreadedExecutor::new_fixed(1)
149 }
150
151 pub fn new_fixed(worker_count: usize) -> MultiThreadedExecutor {
157 let mut mte = Builder::new()
158 .with_name(&format!("MTExec Fixed {worker_count}"))
159 .with_max_workers(worker_count)
160 .build();
161
162 for _i in 0..worker_count {
163 if let Err(e) = mte.add_worker() {
164 error!("Error adding worker: {e:?}");
165 }
166 }
167
168 mte
169 }
170
171 pub fn submit<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
185 &mut self,
186 fut: F,
187 ) -> Result<TaskHandle<T>, TaskError> {
188 let complete = Arc::new(CompletableTask::new());
189 let task = TaskExchange {
190 inner: Box::pin(SingleThreadTask::<T>::new(Box::pin(fut), complete.clone())),
191 };
192 let task = WorkerCommand::Run(task);
193
194 let task = match self.exchanger.try_push(task) {
195 Ok(()) => {
196 return Ok(TaskHandle {
197 completer: complete,
198 })
199 }
200 Err(e) => match e {
201 ExchangerError::TaskError(e) => {
202 return Err(e);
203 }
204 ExchangerError::ExchangerFull(task) => {
205 if self.worker_ctr < self.max_workers {
206 if let Err(e) = self.add_worker() {
207 error!("Tried to add worker, but could not: {e:?}");
208 }
209 }
210 task
211 }
212 ExchangerError::ExchangerEmpty => {
213 return Err(TaskError::ExchangerError);
214 }
215 },
216 };
217
218 if let Err(e) = self.exchanger.push(task) {
219 error!("Error exchanging task: {e:?}");
220 return match e {
221 ExchangerError::TaskError(e) => Err(e),
222 _ => {
223 error!("error exchanging task with worker: {e:?}");
224 Err(TaskError::ExchangerError)
225 }
226 };
227 }
228
229 Ok(TaskHandle {
230 completer: complete,
231 })
232 }
233
234 pub fn add_worker(&mut self) -> Result<(), std::io::Error> {
241 let name = format!("{}: Worker {}", self.name, self.worker_ctr);
242 self.worker_ctr += 1;
243 let worker = Worker::new(self.exchanger.clone(), name)?;
244 self.workers.push_back(worker);
245 Ok(())
246 }
247
248 pub fn remove_worker(&mut self) {
253 fn try_remove_idle(workers: &mut VecDeque<Worker>) -> bool {
256 for idx in 0..workers.len() {
257 let Some(worker) = workers.get(idx) else {
258 break;
259 };
260 let finished = if let Some(handle) = &worker.handle {
261 handle.is_finished()
262 } else {
263 false
264 };
265 if finished {
266 let worker = workers.swap_remove_back(idx);
267 if let Some(worker) = worker {
268 drop(worker);
270 }
271 return true;
272 }
273 }
274 false
275 }
276
277 if try_remove_idle(&mut self.workers) {
279 debug!("Successfully removed idle worker.");
280 return;
281 }
282
283 if let Err(e) = self.exchanger.push(WorkerCommand::Close) {
285 match e {
286 ExchangerError::TaskError(e) => {
287 if e != TaskError::ExecutorStoppingError {
288 error!("Error commanding worker to close: {e:?}");
289 }
290 }
291 e => {
292 error!("Error commanding worker to close: {e:?}");
293 }
294 }
295 }
296 }
297
298 pub fn shutdown(&self) {
299 self.exchanger.shutdown()
300 }
301}
302
303impl Drop for MultiThreadedExecutor {
304 fn drop(&mut self) {
305 self.exchanger.shutdown();
306 while !self.workers.is_empty() {
307 self.remove_worker();
308 }
309 }
310}
311
/// Message a worker thread takes from the exchanger.
pub(crate) enum WorkerCommand {
    /// Tells the receiving worker to leave its loop and let its thread end.
    Close,
    /// Carries a task for the receiving worker to run to completion.
    Run(TaskExchange),
}
316
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use log::{debug, trace};

    use crate::MultiThreadedExecutor;

    /// One worker, 100 tasks: after `shutdown`, each handle must still yield
    /// the index the task was submitted with.
    #[test]
    pub fn test_one() {
        let mut exec = MultiThreadedExecutor::new_fixed(1);

        let answers: Vec<_> = (0..100)
            .map(|i| {
                let handle = exec
                    .submit(async move {
                        std::thread::sleep(Duration::from_millis(1));
                        i
                    })
                    .unwrap();
                trace!("Submitted {i}");
                handle
            })
            .collect();

        exec.shutdown();

        for (expected, answer) in (0..).zip(answers) {
            let ans = answer.get().unwrap();
            assert_eq!(ans, expected);
        }
    }

    /// Ten workers, 1000 tasks: each handle must yield the index the task was
    /// submitted with.
    #[test]
    pub fn test_ten() {
        let mut exec = MultiThreadedExecutor::new_fixed(10);

        let answers: Vec<_> = (0..1000)
            .map(|i| {
                let handle = exec
                    .submit(async move {
                        std::thread::sleep(Duration::from_millis(1));
                        i
                    })
                    .unwrap();
                trace!("Submitted {i}");
                handle
            })
            .collect();
        debug!("Submitted all");

        for (expected, answer) in (0..).zip(answers) {
            let ans = answer.get().unwrap();
            assert_eq!(ans, expected);
        }
    }
}