parallel_processor/execution_manager/execution_context.rs

use crate::execution_manager::async_channel::{AsyncChannel, DoublePriorityAsyncChannel};
use crate::execution_manager::executor::AsyncExecutor;
use crate::execution_manager::executor_address::{ExecutorAddress, WeakExecutorAddress};
use crate::execution_manager::memory_tracker::MemoryTrackerManager;
use crate::execution_manager::objects_pool::{ObjectsPool, PoolObject, PoolObjectTrait};
use crate::execution_manager::packet::{Packet, PacketAny, PacketTrait, PacketsPool};
use crate::execution_manager::thread_pool::ExecutorsHandle;
use dashmap::DashMap;
use parking_lot::{Condvar, Mutex};
use std::any::{Any, TypeId};
use std::cell::UnsafeCell;
use std::collections::HashMap;
use std::ops::Deref;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use std::time::Duration;
use tokio::sync::Semaphore;
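
// Capacity of the pool used to allocate per-executor input queues.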
const ADDRESSES_BUFFER_SIZE: usize = 1;
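
/// How pools for an executor type's output packets are allocated: no pool,
/// one pool shared by all executor instances, or a distinct pool per instance.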
pub enum PoolAllocMode {
    None,
    Shared { capacity: usize },
    Distinct { capacity: usize },
}
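
/// Internal counterpart of [`PoolAllocMode`], holding the actual pool
/// object(s) for an executor type.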
pub(crate) enum PacketsPoolStrategy<E: AsyncExecutor> {
    None,
    Shared(Arc<PoolObject<PacketsPool<E::OutputPacket>>>),
    Distinct {
        pools_allocator: ObjectsPool<PacketsPool<E::OutputPacket>>,
    },
}
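
/// Guard stored in an `ExecutorAddress` (its `executor_keeper`). Both fields
/// are written once, unsafely, by `ExecutionContext::register_executors_batch`
/// and read again only when the dropper is dropped.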
pub struct ExecutorDropper {
    weak_addr: UnsafeCell<WeakExecutorAddress>,
    context: UnsafeCell<Weak<ExecutionContext>>,
}

impl ExecutorDropper {
    pub fn new() -> Self {
        Self {
            weak_addr: UnsafeCell::new(WeakExecutorAddress::empty()),
            context: UnsafeCell::new(Weak::new()),
        }
    }
}
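
// On drop, deregister this executor's address from the owning context,
// provided the context is still alive.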
impl Drop for ExecutorDropper {
    fn drop(&mut self) {
        let scheduler = unsafe { &*(self.context.get()) };
        if let Some(context) = scheduler.upgrade() {
            let address = unsafe { &*(self.weak_addr.get()) };
            context.dealloc_address(address.clone());
        }
    }
}
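
/// Type-erased channel feeding input packets to a single executor.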
pub(crate) type PacketsChannel = AsyncChannel<PacketAny>;
impl PoolObjectTrait for PacketsChannel {
    type InitData = usize;

    fn allocate_new(size: &Self::InitData) -> Self {
        Self::new(*size)
    }

    fn reset(&mut self) {
        if self.try_recv().is_some() {
            panic!("Packets channel not empty!");
        }
        self.reopen();
    }
}
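
/// Shared state of a running executor pipeline: the registry of executor
/// addresses and their input queues, the per-type packet pools and counters
/// of live executors, and the primitives used to start the pipeline and wait
/// for its completion.
///
/// Typical lifecycle (a sketch; `E` is any `AsyncExecutor`):
/// `register_executor_type::<E>`, then `register_executors_batch`, then
/// `start`, and finally `wait_for_completion` or `join_all`.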
pub struct ExecutionContext {
    queues_pool: ObjectsPool<PacketsChannel>,
    pub(crate) waiting_addresses: Mutex<
        HashMap<
            TypeId,
            DoublePriorityAsyncChannel<(
                WeakExecutorAddress,
                Arc<AtomicU64>,
                Arc<PoolObject<PacketsChannel>>,
                Arc<dyn Any + Sync + Send + 'static>,
            )>,
        >,
    >,
    pub(crate) active_executors_counters: DashMap<TypeId, Arc<AtomicU64>>,
    pub(crate) addresses_map: DashMap<WeakExecutorAddress, Arc<PoolObject<PacketsChannel>>>,
    pub(crate) packet_pools: DashMap<TypeId, Box<dyn Any + Sync + Send>>,
    pub(crate) memory_tracker: Arc<MemoryTrackerManager>,
    pub(crate) start_semaphore: Semaphore,
    wait_mutex: Mutex<()>,
    pub(crate) wait_condvar: Condvar,
}
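
// Start gate: `start` releases this many permits and `join_all` waits until
// it can reclaim them all. Kept within tokio's `Semaphore::MAX_PERMITS`.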
const MAX_SEMAPHORE_PERMITS: u32 = u32::MAX >> 3;

impl ExecutionContext {
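    /// Creates an empty context; executor types must be registered (via
    /// `register_executor_type`) before any executor addresses are.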
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            queues_pool: ObjectsPool::new(ADDRESSES_BUFFER_SIZE, 0),
            waiting_addresses: Mutex::new(HashMap::new()),
            active_executors_counters: DashMap::new(),
            addresses_map: DashMap::new(),
            packet_pools: DashMap::new(),
            memory_tracker: Arc::new(MemoryTrackerManager::new()),
            start_semaphore: Semaphore::new(0),
            wait_mutex: Mutex::new(()),
            wait_condvar: Condvar::new(),
        })
    }
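
    /// Registers executor type `E`: its live-instance counter, its
    /// waiting-addresses channel, and the packet pool strategy selected by
    /// `pool_alloc_mode`.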
    pub fn register_executor_type<E: AsyncExecutor>(
        &self,
        executors_max_count: usize,
        pool_alloc_mode: PoolAllocMode,
        pool_init_data: <E::OutputPacket as PoolObjectTrait>::InitData,
    ) {
        self.active_executors_counters
            .insert(TypeId::of::<E>(), Arc::new(AtomicU64::new(0)));
        self.waiting_addresses
            .lock()
            .insert(TypeId::of::<E>(), DoublePriorityAsyncChannel::new(0));
        self.packet_pools.insert(
            TypeId::of::<E>(),
            Box::new(match pool_alloc_mode {
                PoolAllocMode::None => PacketsPoolStrategy::<E>::None,
                PoolAllocMode::Shared { capacity } => {
                    PacketsPoolStrategy::<E>::Shared(Arc::new(PoolObject::new_simple(
                        PacketsPool::new(capacity, pool_init_data, &self.memory_tracker),
                    )))
                }
                PoolAllocMode::Distinct { capacity } => PacketsPoolStrategy::<E>::Distinct {
                    pools_allocator: ObjectsPool::new(
                        executors_max_count,
                        (capacity, pool_init_data, self.memory_tracker.clone()),
                    ),
                },
            }),
        );
    }
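
    /// Returns the number of currently allocated executors of the given type.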
    pub fn get_allocated_executors(&self, executor_type_id: &TypeId) -> u64 {
        self.active_executors_counters
            .get(executor_type_id)
            .unwrap()
            .load(Ordering::SeqCst)
    }
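
    /// Registers a batch of executor addresses at the given priority: wires
    /// each address' dropper to this context, allocates its input queue,
    /// bumps the live counter and publishes the address on the waiting
    /// channel of its executor type.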
    pub fn register_executors_batch(
        self: &Arc<Self>,
        executors: Vec<ExecutorAddress>,
        priority: usize,
    ) {
        let mut waiting_addresses = self.waiting_addresses.lock();

        for executor in executors {
            unsafe {
                *(executor.executor_keeper.context.get()) = Arc::downgrade(self);
                *(executor.executor_keeper.weak_addr.get()) = executor.to_weak();
            }

            let queue = Arc::new(self.queues_pool.alloc_object_force());

            let old_val = self.addresses_map.insert(executor.to_weak(), queue.clone());

            let counter = self
                .active_executors_counters
                .get(&executor.executor_type_id)
                .unwrap()
                .clone();

            counter.fetch_add(1, Ordering::SeqCst);
            assert!(old_val.is_none());

            waiting_addresses
                .get_mut(&executor.executor_type_id)
                .unwrap()
                .send_with_priority(
                    (executor.to_weak(), counter, queue, executor.init_data),
                    priority,
                    false,
                );
        }
    }
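
    /// Queues an already type-erased packet on the target executor's channel.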
    pub(crate) fn add_input_packet(&self, addr: ExecutorAddress, packet: PacketAny) {
        self.addresses_map
            .get(&addr.to_weak())
            .unwrap()
            .send(packet, false);
    }
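
    /// Queues a typed packet on the target executor's channel, accounting
    /// for it in the memory tracker.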
    pub fn send_packet<T: PacketTrait>(&self, addr: ExecutorAddress, packet: Packet<T>) {
        self.memory_tracker.add_queue_packet(packet.deref());
        self.addresses_map
            .get(&addr.to_weak())
            .unwrap()
            .send(packet.upcast(), false);
    }
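
    /// Obtains an output-packet pool for executor type `E` according to its
    /// registered strategy. For `Distinct` pools, `force` allocates
    /// immediately instead of awaiting a free pool.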
    pub(crate) async fn allocate_pool<E: AsyncExecutor>(
        &self,
        force: bool,
    ) -> Option<Arc<PoolObject<PacketsPool<E::OutputPacket>>>> {
        match self
            .packet_pools
            .get(&TypeId::of::<E>())
            .unwrap()
            .downcast_ref::<PacketsPoolStrategy<E>>()
            .unwrap()
        {
            PacketsPoolStrategy::None => None,
            PacketsPoolStrategy::Shared(pool) => Some(pool.clone()),
            PacketsPoolStrategy::Distinct { pools_allocator } => Some(Arc::new(if force {
                pools_allocator.alloc_object_force()
            } else {
                pools_allocator.alloc_object().await
            })),
        }
    }
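
    // Called from `ExecutorDropper::drop`: removes the address and releases
    // its input queue.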
    fn dealloc_address(&self, addr: WeakExecutorAddress) {
        let channel = self.addresses_map.remove(&addr).unwrap();
        channel.1.release();
    }
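
    /// Opens the start gate by releasing the full permit budget on
    /// `start_semaphore`.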
    pub fn start(&self) {
        self.start_semaphore
            .add_permits(MAX_SEMAPHORE_PERMITS as usize);
    }
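
    /// Blocks until every executor of type `E` has been deallocated,
    /// re-checking the live counter every 100 ms.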
    pub fn wait_for_completion<E: AsyncExecutor>(&self, _handle: ExecutorsHandle<E>) {
        let mut wait_mutex = self.wait_mutex.lock();
        let counter = self
            .active_executors_counters
            .get(&TypeId::of::<E>())
            .unwrap()
            .value()
            .clone();
        loop {
            // crate::log_info!(
            //     "Waiting for {} {}",
            //     std::any::type_name::<E>(),
            //     counter.load(Ordering::Relaxed)
            // );
            if counter.load(Ordering::Relaxed) == 0 {
                return;
            }
            self.wait_condvar
                .wait_for(&mut wait_mutex, Duration::from_millis(100));
        }
    }
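
    /// Returns how many executors of type `E` are still allocated, and thus
    /// pending completion.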
    pub fn get_pending_executors_count<E: AsyncExecutor>(
        &self,
        _handle: ExecutorsHandle<E>,
    ) -> u64 {
        self.get_allocated_executors(&TypeId::of::<E>())
    }
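
    /// Releases all waiting-address channels, then blocks until the full
    /// permit budget of `start_semaphore` can be re-acquired.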
    pub fn join_all(&self) {
        let addresses = self.waiting_addresses.lock();
        addresses.iter().for_each(|addr| addr.1.release());
        drop(addresses);

        let mut wait_mutex = self.wait_mutex.lock();
        loop {
            self.wait_condvar
                .wait_for(&mut wait_mutex, Duration::from_millis(100));
            if self
                .start_semaphore
                .try_acquire_many(MAX_SEMAPHORE_PERMITS)
                .is_ok()
            {
                break;
            }
        }
    }
}