// parallel_processor/execution_manager/thread_pool.rs

use crate::execution_manager::executor::{
    AddressConsumer, AddressProducer, AsyncExecutor, ExecutorReceiver,
};
use crate::execution_manager::notifier::Notifier;
use crate::execution_manager::objects_pool::ObjectsPool;
use crate::execution_manager::packet::Packet;
use crate::execution_manager::packets_channel::bounded::{
    packets_channel_bounded, PacketsChannelReceiverBounded, PacketsChannelSenderBounded,
};
use crate::execution_manager::packets_channel::unbounded::{
    packets_channel_unbounded, PacketsChannelSenderUnbounded,
};
use crate::execution_manager::scheduler::{
    init_current_thread, run_blocking_op, uninit_current_thread, Scheduler,
};
use std::marker::PhantomData;
use std::mem::transmute;
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
use std::sync::atomic::{fence, AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Barrier, Weak};
use std::thread::JoinHandle;
21
22pub struct ExecThreadPool<E: AsyncExecutor> {
23 name: String,
24 threads_count: usize,
25 threads: Vec<JoinHandle<()>>,
26 scoped_thread_pool: Option<Arc<ScopedThreadPool>>,
27 allow_spawning: bool,
28 _phantom: PhantomData<E>,
29}
30
31pub struct ExecutorsHandle<E: AsyncExecutor> {
32 pub(crate) spawner: PacketsChannelSenderUnbounded<AddressConsumer<E>>,
33 pub(crate) channels_pool:
34 Arc<ObjectsPool<PacketsChannelReceiverBounded<Packet<E::InputPacket>>>>,
35 pub(crate) threads_count: usize,
36}
37impl<E: AsyncExecutor> Clone for ExecutorsHandle<E> {
38 fn clone(&self) -> Self {
39 ExecutorsHandle {
40 spawner: self.spawner.clone(),
41 channels_pool: self.channels_pool.clone(),
42 threads_count: self.threads_count,
43 }
44 }
45}
46
47impl<E: AsyncExecutor> ExecutorsHandle<E> {
48 pub fn add_input_data(
49 &self,
50 init_data: E::InitData,
51 data: impl ExactSizeIterator<Item = E::InputPacket>,
52 ) {
53 let init_data = Arc::new(init_data);
54 if E::ALLOW_PARALLEL_ADDRESS_EXECUTION {
55 let address = self.create_new_address(init_data, false);
56 for value in data {
57 address.send_packet(Packet::new_simple(value));
58 }
59 } else {
60 for value in data {
61 let address = self.create_new_address(init_data.clone(), false);
62 address.send_packet(Packet::new_simple(value));
63 }
64 }
65 }
66
67 pub fn create_new_address(
68 &self,
69 data: Arc<E::InitData>,
70 high_priority: bool,
71 ) -> AddressProducer<E::InputPacket> {
72 let channel = self.channels_pool.alloc_object();
73 let sender = channel.make_sender();
74 self.spawner.send_with_priority(
75 AddressConsumer {
76 init_data: data,
77 packets_queue: Arc::new(channel),
78 },
79 high_priority,
80 None,
81 );
82
83 AddressProducer {
84 packets_queue: sender,
85 }
86 }
87
88 pub fn create_new_address_with_limit(
89 &self,
90 data: Arc<E::InitData>,
91 high_priority: bool,
92 max_in_queue: usize,
93 ) -> AddressProducer<E::InputPacket> {
94 let channel = self.channels_pool.alloc_object();
95 let sender = channel.make_sender();
96 self.spawner.send_with_priority(
97 AddressConsumer {
98 init_data: data,
99 packets_queue: Arc::new(channel),
100 },
101 high_priority,
102 Some(max_in_queue),
103 );
104
105 AddressProducer {
106 packets_queue: sender,
107 }
108 }
109
110 pub fn create_new_addresses(
111 &self,
112 addr_iterator: impl Iterator<Item = Arc<E::InitData>>,
113 out_addresses: &mut Vec<AddressProducer<E::InputPacket>>,
114 max_in_queue: Option<usize>,
115 high_priority: bool,
116 ) {
117 out_addresses.reserve(addr_iterator.size_hint().0);
118
119 let mut count = 0;
120 self.spawner.send_batch(
121 addr_iterator.map(|init_data| {
122 let channel = self.channels_pool.alloc_object();
123 let sender = channel.make_sender();
124 count += 1;
125
126 out_addresses.push(AddressProducer {
127 packets_queue: sender,
128 });
129 AddressConsumer {
130 init_data,
131 packets_queue: Arc::new(channel),
132 }
133 }),
134 max_in_queue,
135 high_priority,
136 );
137
138 assert!(
139 count <= self.threads_count,
140 "Cannot create more parallel addresses than the number of threads {} vs {}",
141 count,
142 self.threads_count
143 );
144 }
145
146 #[inline(always)]
147 pub fn get_pending_executors_count(&self) -> usize {
148 self.spawner.len()
149 }
150}
151
152impl<E: AsyncExecutor> ExecThreadPool<E> {
153 pub fn new(threads_count: usize, name: &str, allow_spawning: bool) -> Self {
154 Self {
155 name: name.to_string(),
156 threads_count,
157 threads: Vec::with_capacity(threads_count),
158 scoped_thread_pool: None,
159 allow_spawning,
160 _phantom: PhantomData,
161 }
162 }
163
164 pub fn start(
165 &mut self,
166 scheduler: Arc<Scheduler>,
167 global_params: &Arc<E::GlobalParams>,
168 ) -> ExecutorsHandle<E> {
169 let (addresses_sender, addresses_receiver) = packets_channel_unbounded();
170
171 let name = self.name.clone();
172 let scoped_thread_pool = if self.allow_spawning {
173 Some(Arc::new(ScopedThreadPool::new(
174 self.threads_count,
175 &name,
176 &scheduler,
177 )))
178 } else {
179 None
180 };
181 self.scoped_thread_pool = scoped_thread_pool.clone();
182
183 let handle = ExecutorsHandle {
184 spawner: addresses_sender,
185 channels_pool: Arc::new(ObjectsPool::new(
187 self.threads_count,
188 if E::ALLOW_PARALLEL_ADDRESS_EXECUTION {
189 self.threads_count * 2
190 } else {
191 2
192 },
193 )),
194 threads_count: self.threads_count,
195 };
196
197 let barrier = Arc::new(Barrier::new(self.threads_count));
198
199 for i in 0..self.threads_count {
200 let global_params = global_params.clone();
201 let addresses_receiver = addresses_receiver.clone();
202 let scheduler = scheduler.clone();
203 let scoped_thread_pool = scoped_thread_pool.clone();
204 let barrier = barrier.clone();
205
206 let thread = std::thread::Builder::new()
207 .name(format!("{}-{}", self.name, i))
208 .spawn(move || {
209 init_current_thread(scheduler.clone());
210
211 let mut executor = E::new();
212 executor.executor_main(
213 &global_params,
214 ExecutorReceiver {
215 addresses_receiver,
216 thread_pool: scoped_thread_pool,
217 barrier,
218 },
219 );
220 })
221 .expect("Failed to spawn thread");
222 self.threads.push(thread);
223 }
224
225 handle
226 }
227
228 pub fn join(mut self) {
229 for thread in self.threads {
230 if let Err(e) = thread.join() {
231 eprintln!("Error joining thread: {:?}", e);
232 }
233 }
234 if let Some(thread_pool) = self.scoped_thread_pool.take() {
235 thread_pool.dispose();
236 }
237 }
238}
239
240pub struct ScopedThreadPool {
241 threads: Vec<JoinHandle<()>>,
242 execution_queue: PacketsChannelSenderBounded<Arc<ScopedTaskData>>,
243}
244
245struct ScopedTaskData {
246 running_count: AtomicUsize,
247 running_notifier: Notifier,
248 function: Weak<dyn Fn(usize) + Sync + Send>,
249 running_index: AtomicUsize,
250}
251
252impl ScopedThreadPool {
253 pub fn new(threads_count: usize, name: &str, scheduler: &Arc<Scheduler>) -> Self {
254 let (execution_queue, execution_dispatch) =
255 packets_channel_bounded::<Arc<ScopedTaskData>>(threads_count * 2);
256
257 Self {
258 threads: (0..threads_count)
259 .map(|i| {
260 let execution_dispatch = execution_dispatch.clone();
261 let scheduler = scheduler.clone();
262 std::thread::Builder::new()
263 .name(format!("{}-{}-threadpool", name, i))
264 .spawn(move || {
265 while let Some(task) = execution_dispatch.recv() {
266 task.running_count.fetch_add(1, Ordering::SeqCst);
267 if let Some(function) = task.function.upgrade() {
268 let running_index =
269 task.running_index.fetch_add(1, Ordering::Relaxed);
270 init_current_thread(scheduler.clone());
271 function(running_index);
272 uninit_current_thread();
273 }
274 let count = task.running_count.fetch_sub(1, Ordering::SeqCst);
275 if count == 1 {
276 task.running_notifier.notify_all();
277 }
278 }
279 })
280 .unwrap()
281 })
282 .collect(),
283 execution_queue,
284 }
285 }
286
287 pub fn run_scoped_optional(
288 &self,
289 concurrency: usize,
290 callback: impl Fn(usize) + Sync + Send + Copy,
291 ) {
292 assert!(concurrency > 0);
293 if concurrency == 1 {
294 callback(0);
295 } else {
296 let function: Arc<dyn Fn(usize) + Sync + Send> = Arc::new(callback);
297 let function_lt_extended: Arc<dyn Fn(usize) + Sync + Send> =
298 unsafe { transmute(function) };
299
300 let task_data = Arc::new(ScopedTaskData {
301 running_count: AtomicUsize::new(0),
302 running_notifier: Notifier::new(),
303 function: Arc::downgrade(&function_lt_extended),
304 running_index: AtomicUsize::new(1),
305 });
306
307 for _ in 1..concurrency {
308 self.execution_queue.send(task_data.clone());
309 }
310
311 callback(0);
312
313 drop(function_lt_extended);
314 if task_data.running_count.load(Ordering::SeqCst) > 0 {
316 run_blocking_op(|| {
317 task_data
318 .running_notifier
319 .wait_for_condition(|| task_data.running_count.load(Ordering::SeqCst) == 0);
320 });
321 }
322 }
323 }
324
325 pub fn dispose(&self) {
326 self.execution_queue.dispose();
327 }
328}
329
330impl Drop for ScopedThreadPool {
331 fn drop(&mut self) {
332 for thread in self.threads.drain(..) {
333 thread.join().unwrap();
334 }
335 }
336}