parallel_processor/execution_manager/thread_pool.rs

use crate::execution_manager::executor::{
    AddressConsumer, AddressProducer, AsyncExecutor, ExecutorReceiver,
};
use crate::execution_manager::notifier::Notifier;
use crate::execution_manager::objects_pool::ObjectsPool;
use crate::execution_manager::packet::Packet;
use crate::execution_manager::packets_channel::bounded::{
    packets_channel_bounded, PacketsChannelReceiverBounded, PacketsChannelSenderBounded,
};
use crate::execution_manager::packets_channel::unbounded::{
    packets_channel_unbounded, PacketsChannelSenderUnbounded,
};
use crate::execution_manager::scheduler::{
    init_current_thread, run_blocking_op, uninit_current_thread, Scheduler,
};
use std::marker::PhantomData;
use std::mem::transmute;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Barrier, Weak};
use std::thread::JoinHandle;

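/// Pool of worker threads that each run an instance of an [`AsyncExecutor`].
/// Optionally owns a [`ScopedThreadPool`] that executors can use to spawn scoped parallel work.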
pub struct ExecThreadPool<E: AsyncExecutor> {
    name: String,
    threads_count: usize,
    threads: Vec<JoinHandle<()>>,
    scoped_thread_pool: Option<Arc<ScopedThreadPool>>,
    allow_spawning: bool,
    _phantom: PhantomData<E>,
}

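/// Cloneable handle used to feed input packets and create new executor addresses
/// on a running [`ExecThreadPool`].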
pub struct ExecutorsHandle<E: AsyncExecutor> {
    pub(crate) spawner: PacketsChannelSenderUnbounded<AddressConsumer<E>>,
    pub(crate) channels_pool:
        Arc<ObjectsPool<PacketsChannelReceiverBounded<Packet<E::InputPacket>>>>,
    pub(crate) threads_count: usize,
}

impl<E: AsyncExecutor> Clone for ExecutorsHandle<E> {
    fn clone(&self) -> Self {
        ExecutorsHandle {
            spawner: self.spawner.clone(),
            channels_pool: self.channels_pool.clone(),
            threads_count: self.threads_count,
        }
    }
}

impl<E: AsyncExecutor> ExecutorsHandle<E> {
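    /// Wraps `init_data` in an `Arc` and dispatches the input packets: to a single shared
    /// address when the executor allows parallel address execution, otherwise to one new
    /// address per packet.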
    pub fn add_input_data(
        &self,
        init_data: E::InitData,
        data: impl ExactSizeIterator<Item = E::InputPacket>,
    ) {
        let init_data = Arc::new(init_data);
        if E::ALLOW_PARALLEL_ADDRESS_EXECUTION {
            let address = self.create_new_address(init_data, false);
            for value in data {
                address.send_packet(Packet::new_simple(value));
            }
        } else {
            for value in data {
                let address = self.create_new_address(init_data.clone(), false);
                address.send_packet(Packet::new_simple(value));
            }
        }
    }

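    /// Allocates a packets channel from the pool, queues an [`AddressConsumer`] for the
    /// executors (optionally with high priority, without a queue limit) and returns the
    /// matching [`AddressProducer`].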
    pub fn create_new_address(
        &self,
        data: Arc<E::InitData>,
        high_priority: bool,
    ) -> AddressProducer<E::InputPacket> {
        let channel = self.channels_pool.alloc_object();
        let sender = channel.make_sender();
        self.spawner.send_with_priority(
            AddressConsumer {
                init_data: data,
                packets_queue: Arc::new(channel),
            },
            high_priority,
            None,
        );

        AddressProducer {
            packets_queue: sender,
        }
    }

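    /// Same as [`Self::create_new_address`], but forwards `Some(max_in_queue)` as the queue
    /// limit when enqueueing the [`AddressConsumer`].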
    pub fn create_new_address_with_limit(
        &self,
        data: Arc<E::InitData>,
        high_priority: bool,
        max_in_queue: usize,
    ) -> AddressProducer<E::InputPacket> {
        let channel = self.channels_pool.alloc_object();
        let sender = channel.make_sender();
        self.spawner.send_with_priority(
            AddressConsumer {
                init_data: data,
                packets_queue: Arc::new(channel),
            },
            high_priority,
            Some(max_in_queue),
        );

        AddressProducer {
            packets_queue: sender,
        }
    }

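    /// Batch variant: creates one address per `init_data` item and pushes the corresponding
    /// producers into `out_addresses`. Panics if more addresses are requested than there are
    /// threads in the pool.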
    pub fn create_new_addresses(
        &self,
        addr_iterator: impl Iterator<Item = Arc<E::InitData>>,
        out_addresses: &mut Vec<AddressProducer<E::InputPacket>>,
        max_in_queue: Option<usize>,
        high_priority: bool,
    ) {
        out_addresses.reserve(addr_iterator.size_hint().0);

        let mut count = 0;
        self.spawner.send_batch(
            addr_iterator.map(|init_data| {
                let channel = self.channels_pool.alloc_object();
                let sender = channel.make_sender();
                count += 1;

                out_addresses.push(AddressProducer {
                    packets_queue: sender,
                });
                AddressConsumer {
                    init_data,
                    packets_queue: Arc::new(channel),
                }
            }),
            max_in_queue,
            high_priority,
        );

        assert!(
            count <= self.threads_count,
            "Cannot create more parallel addresses ({}) than the number of threads ({})",
            count,
            self.threads_count
        );
    }

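    /// Number of addresses still waiting in the spawner queue.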
    #[inline(always)]
    pub fn get_pending_executors_count(&self) -> usize {
        self.spawner.len()
    }
}

impl<E: AsyncExecutor> ExecThreadPool<E> {
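    /// Creates the thread pool configuration; no threads are spawned until [`Self::start`] is called.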
    pub fn new(threads_count: usize, name: &str, allow_spawning: bool) -> Self {
        Self {
            name: name.to_string(),
            threads_count,
            threads: Vec::with_capacity(threads_count),
            scoped_thread_pool: None,
            allow_spawning,
            _phantom: PhantomData,
        }
    }

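    /// Spawns `threads_count` worker threads, each registering with the scheduler and running
    /// `E::executor_main`, and returns the [`ExecutorsHandle`] used to submit work. If spawning
    /// was allowed, a shared [`ScopedThreadPool`] is created and handed to the executors.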
    pub fn start(
        &mut self,
        scheduler: Arc<Scheduler>,
        global_params: &Arc<E::GlobalParams>,
    ) -> ExecutorsHandle<E> {
        let (addresses_sender, addresses_receiver) = packets_channel_unbounded();

        let name = self.name.clone();
        let scoped_thread_pool = if self.allow_spawning {
            Some(Arc::new(ScopedThreadPool::new(
                self.threads_count,
                &name,
                &scheduler,
            )))
        } else {
            None
        };
        self.scoped_thread_pool = scoped_thread_pool.clone();

        let handle = ExecutorsHandle {
            spawner: addresses_sender,
            // The spawner channel itself is unbounded; the effective capacity is limited by the
            // thread count and the channels pool capacity.
            channels_pool: Arc::new(ObjectsPool::new(
                self.threads_count,
                if E::ALLOW_PARALLEL_ADDRESS_EXECUTION {
                    self.threads_count * 2
                } else {
                    2
                },
            )),
            threads_count: self.threads_count,
        };

        let barrier = Arc::new(Barrier::new(self.threads_count));

        for i in 0..self.threads_count {
            let global_params = global_params.clone();
            let addresses_receiver = addresses_receiver.clone();
            let scheduler = scheduler.clone();
            let scoped_thread_pool = scoped_thread_pool.clone();
            let barrier = barrier.clone();

            let thread = std::thread::Builder::new()
                .name(format!("{}-{}", self.name, i))
                .spawn(move || {
                    init_current_thread(scheduler.clone());

                    let mut executor = E::new();
                    executor.executor_main(
                        &global_params,
                        ExecutorReceiver {
                            addresses_receiver,
                            thread_pool: scoped_thread_pool,
                            barrier,
                        },
                    );
                })
                .expect("Failed to spawn thread");
            self.threads.push(thread);
        }

        handle
    }

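    /// Waits for all worker threads to exit, then disposes the scoped thread pool (if any).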
    pub fn join(mut self) {
        for thread in self.threads {
            if let Err(e) = thread.join() {
                eprintln!("Error joining thread: {:?}", e);
            }
        }
        if let Some(thread_pool) = self.scoped_thread_pool.take() {
            thread_pool.dispose();
        }
    }
}

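/// Auxiliary thread pool that lets a caller run a closure on several threads at once
/// (see [`Self::run_scoped_optional`]).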
pub struct ScopedThreadPool {
    threads: Vec<JoinHandle<()>>,
    execution_queue: PacketsChannelSenderBounded<Arc<ScopedTaskData>>,
}

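/// Shared state for one scoped invocation: the (lifetime-erased) closure, a counter of threads
/// currently running it, a notifier signalled when the last one finishes, and an index dispenser
/// giving each thread a distinct invocation index.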
struct ScopedTaskData {
    running_count: AtomicUsize,
    running_notifier: Notifier,
    function: Weak<dyn Fn(usize) + Sync + Send>,
    running_index: AtomicUsize,
}

impl ScopedThreadPool {
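    /// Spawns `threads_count` worker threads that pull [`ScopedTaskData`] tasks from a bounded
    /// channel and invoke the closure while it is still alive (i.e. while its `Weak` upgrades).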
    pub fn new(threads_count: usize, name: &str, scheduler: &Arc<Scheduler>) -> Self {
        let (execution_queue, execution_dispatch) =
            packets_channel_bounded::<Arc<ScopedTaskData>>(threads_count * 2);

        Self {
            threads: (0..threads_count)
                .map(|i| {
                    let execution_dispatch = execution_dispatch.clone();
                    let scheduler = scheduler.clone();
                    std::thread::Builder::new()
                        .name(format!("{}-{}-threadpool", name, i))
                        .spawn(move || {
                            while let Some(task) = execution_dispatch.recv() {
                                task.running_count.fetch_add(1, Ordering::SeqCst);
                                if let Some(function) = task.function.upgrade() {
                                    let running_index =
                                        task.running_index.fetch_add(1, Ordering::Relaxed);
                                    init_current_thread(scheduler.clone());
                                    function(running_index);
                                    uninit_current_thread();
                                }
                                let count = task.running_count.fetch_sub(1, Ordering::SeqCst);
                                if count == 1 {
                                    task.running_notifier.notify_all();
                                }
                            }
                        })
                        .unwrap()
                })
                .collect(),
            execution_queue,
        }
    }

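    /// Runs `callback` on up to `concurrency` threads: `callback(0)` runs on the current thread,
    /// while pool threads that pick the task up receive indices starting from 1. Pool threads that
    /// never start the task simply skip it, hence the "optional" in the name. Blocks until every
    /// invocation that did start has finished.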
    pub fn run_scoped_optional(
        &self,
        concurrency: usize,
        callback: impl Fn(usize) + Sync + Send + Copy,
    ) {
        assert!(concurrency > 0);
        if concurrency == 1 {
            callback(0);
        } else {
            // Erase the callback's lifetime so it can be shared with the pool threads. The pool
            // threads only reach it through the `Weak` stored in the task data, and this function
            // blocks below until `running_count` is back to zero.
            let function: Arc<dyn Fn(usize) + Sync + Send> = Arc::new(callback);
            let function_lt_extended: Arc<dyn Fn(usize) + Sync + Send> =
                unsafe { transmute(function) };

            let task_data = Arc::new(ScopedTaskData {
                running_count: AtomicUsize::new(0),
                running_notifier: Notifier::new(),
                function: Arc::downgrade(&function_lt_extended),
                running_index: AtomicUsize::new(1),
            });

            // Queue `concurrency - 1` copies of the task; index 0 runs on the current thread.
            for _ in 1..concurrency {
                self.execution_queue.send(task_data.clone());
            }

            callback(0);

            // Dropping the only strong reference makes queued tasks that have not started yet
            // skip execution (their `Weak::upgrade` fails).
            drop(function_lt_extended);
            // Wait for all the running tasks to terminate
            if task_data.running_count.load(Ordering::SeqCst) > 0 {
                run_blocking_op(|| {
                    task_data
                        .running_notifier
                        .wait_for_condition(|| task_data.running_count.load(Ordering::SeqCst) == 0);
                });
            }
        }
    }

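    /// Closes the execution queue; the pool threads exit their receive loop once `recv`
    /// returns `None`.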
    pub fn dispose(&self) {
        self.execution_queue.dispose();
    }
}

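// Joins the worker threads on drop; they terminate after the execution queue has been disposed.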
impl Drop for ScopedThreadPool {
    fn drop(&mut self) {
        for thread in self.threads.drain(..) {
            thread.join().unwrap();
        }
    }
}