grafeo_core/execution/parallel/
scheduler.rs1use super::morsel::Morsel;
7use crossbeam::deque::{Injector, Steal, Stealer, Worker};
8use parking_lot::Mutex;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11
12pub struct MorselScheduler {
18 num_workers: usize,
20 global_queue: Injector<Morsel>,
22 stealers: Mutex<Vec<Stealer<Morsel>>>,
24 active_morsels: AtomicUsize,
26 total_submitted: AtomicUsize,
28 submission_done: AtomicBool,
30 done: AtomicBool,
32}
33
34impl MorselScheduler {
35 #[must_use]
37 pub fn new(num_workers: usize) -> Self {
38 Self {
39 num_workers,
40 global_queue: Injector::new(),
41 stealers: Mutex::new(Vec::with_capacity(num_workers)),
42 active_morsels: AtomicUsize::new(0),
43 total_submitted: AtomicUsize::new(0),
44 submission_done: AtomicBool::new(false),
45 done: AtomicBool::new(false),
46 }
47 }
48
49 #[must_use]
51 pub fn num_workers(&self) -> usize {
52 self.num_workers
53 }
54
55 pub fn submit(&self, morsel: Morsel) {
57 self.global_queue.push(morsel);
58 self.active_morsels.fetch_add(1, Ordering::Relaxed);
59 self.total_submitted.fetch_add(1, Ordering::Relaxed);
60 }
61
62 pub fn submit_batch(&self, morsels: Vec<Morsel>) {
64 let count = morsels.len();
65 for morsel in morsels {
66 self.global_queue.push(morsel);
67 }
68 self.active_morsels.fetch_add(count, Ordering::Relaxed);
69 self.total_submitted.fetch_add(count, Ordering::Relaxed);
70 }
71
72 pub fn finish_submission(&self) {
74 self.submission_done.store(true, Ordering::Release);
75 if self.active_morsels.load(Ordering::Acquire) == 0 {
77 self.done.store(true, Ordering::Release);
78 }
79 }
80
81 pub fn register_worker(&self, stealer: Stealer<Morsel>) -> usize {
85 let mut stealers = self.stealers.lock();
86 let worker_id = stealers.len();
87 stealers.push(stealer);
88 worker_id
89 }
90
91 pub fn get_global_work(&self) -> Option<Morsel> {
93 loop {
94 match self.global_queue.steal() {
95 Steal::Success(morsel) => return Some(morsel),
96 Steal::Empty => return None,
97 Steal::Retry => continue,
98 }
99 }
100 }
101
102 pub fn steal_work(&self, my_id: usize) -> Option<Morsel> {
104 let stealers = self.stealers.lock();
105 let num_stealers = stealers.len();
106
107 if num_stealers <= 1 {
108 return None;
109 }
110
111 for i in 1..num_stealers {
113 let victim = (my_id + i) % num_stealers;
114 loop {
115 match stealers[victim].steal() {
116 Steal::Success(morsel) => return Some(morsel),
117 Steal::Empty => break,
118 Steal::Retry => continue,
119 }
120 }
121 }
122 None
123 }
124
125 pub fn complete_morsel(&self) {
129 let prev = self.active_morsels.fetch_sub(1, Ordering::Release);
130 if prev == 1 && self.submission_done.load(Ordering::Acquire) {
131 self.done.store(true, Ordering::Release);
132 }
133 }
134
135 #[must_use]
137 pub fn is_done(&self) -> bool {
138 self.done.load(Ordering::Acquire)
139 }
140
141 #[must_use]
143 pub fn is_submission_done(&self) -> bool {
144 self.submission_done.load(Ordering::Acquire)
145 }
146
147 #[must_use]
149 pub fn active_count(&self) -> usize {
150 self.active_morsels.load(Ordering::Relaxed)
151 }
152
153 #[must_use]
155 pub fn total_submitted(&self) -> usize {
156 self.total_submitted.load(Ordering::Relaxed)
157 }
158}
159
160impl std::fmt::Debug for MorselScheduler {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 f.debug_struct("MorselScheduler")
163 .field("num_workers", &self.num_workers)
164 .field(
165 "active_morsels",
166 &self.active_morsels.load(Ordering::Relaxed),
167 )
168 .field(
169 "total_submitted",
170 &self.total_submitted.load(Ordering::Relaxed),
171 )
172 .field(
173 "submission_done",
174 &self.submission_done.load(Ordering::Relaxed),
175 )
176 .field("done", &self.done.load(Ordering::Relaxed))
177 .finish()
178 }
179}
180
181pub struct WorkerHandle {
185 scheduler: Arc<MorselScheduler>,
186 worker_id: usize,
187 local_queue: Worker<Morsel>,
188}
189
190impl WorkerHandle {
191 #[must_use]
193 pub fn new(scheduler: Arc<MorselScheduler>) -> Self {
194 let local_queue = Worker::new_fifo();
195 let worker_id = scheduler.register_worker(local_queue.stealer());
196 Self {
197 scheduler,
198 worker_id,
199 local_queue,
200 }
201 }
202
203 pub fn get_work(&self) -> Option<Morsel> {
207 if let Some(morsel) = self.local_queue.pop() {
209 return Some(morsel);
210 }
211
212 if let Some(morsel) = self.scheduler.get_global_work() {
214 return Some(morsel);
215 }
216
217 if let Some(morsel) = self.scheduler.steal_work(self.worker_id) {
219 return Some(morsel);
220 }
221
222 if self.scheduler.is_submission_done() && self.scheduler.active_count() == 0 {
224 return None;
225 }
226
227 None
228 }
229
230 pub fn push_local(&self, morsel: Morsel) {
232 self.local_queue.push(morsel);
233 self.scheduler
234 .active_morsels
235 .fetch_add(1, Ordering::Relaxed);
236 }
237
238 pub fn complete_morsel(&self) {
240 self.scheduler.complete_morsel();
241 }
242
243 #[must_use]
245 pub fn worker_id(&self) -> usize {
246 self.worker_id
247 }
248
249 #[must_use]
251 pub fn is_done(&self) -> bool {
252 self.scheduler.is_done()
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scheduler_creation() {
        let scheduler = MorselScheduler::new(4);
        assert_eq!(scheduler.num_workers(), 4);
        assert_eq!(scheduler.active_count(), 0);
        assert!(!scheduler.is_done());
    }

    #[test]
    fn test_submit_and_get_work() {
        let scheduler = Arc::new(MorselScheduler::new(2));

        scheduler.submit(Morsel::new(0, 0, 0, 1000));
        scheduler.submit(Morsel::new(1, 0, 1000, 2000));
        assert_eq!(scheduler.total_submitted(), 2);
        assert_eq!(scheduler.active_count(), 2);

        let morsel = scheduler.get_global_work().unwrap();
        assert_eq!(morsel.id, 0);

        scheduler.complete_morsel();
        assert_eq!(scheduler.active_count(), 1);

        let morsel = scheduler.get_global_work().unwrap();
        assert_eq!(morsel.id, 1);
        scheduler.complete_morsel();

        scheduler.finish_submission();
        assert!(scheduler.is_done());
    }

    #[test]
    fn test_submit_batch() {
        let scheduler = MorselScheduler::new(4);

        let morsels = vec![
            Morsel::new(0, 0, 0, 100),
            Morsel::new(1, 0, 100, 200),
            Morsel::new(2, 0, 200, 300),
        ];
        scheduler.submit_batch(morsels);

        assert_eq!(scheduler.total_submitted(), 3);
        assert_eq!(scheduler.active_count(), 3);
    }

    #[test]
    fn test_submit_empty_batch() {
        let scheduler = MorselScheduler::new(1);
        scheduler.submit_batch(Vec::new());
        assert_eq!(scheduler.total_submitted(), 0);
        assert_eq!(scheduler.active_count(), 0);

        scheduler.finish_submission();
        assert!(scheduler.is_done());
    }

    #[test]
    fn test_worker_handle() {
        let scheduler = Arc::new(MorselScheduler::new(2));

        let handle = WorkerHandle::new(Arc::clone(&scheduler));
        assert_eq!(handle.worker_id(), 0);
        assert!(!handle.is_done());

        scheduler.submit(Morsel::new(0, 0, 0, 100));

        let morsel = handle.get_work().unwrap();
        assert_eq!(morsel.id, 0);

        handle.complete_morsel();
        scheduler.finish_submission();

        assert!(handle.is_done());
    }

    #[test]
    fn test_worker_local_queue() {
        let scheduler = Arc::new(MorselScheduler::new(2));
        let handle = WorkerHandle::new(Arc::clone(&scheduler));

        handle.push_local(Morsel::new(0, 0, 0, 100));
        assert_eq!(scheduler.active_count(), 1);

        let morsel = handle.get_work().unwrap();
        assert_eq!(morsel.id, 0);

        handle.complete_morsel();
        assert_eq!(scheduler.active_count(), 0);
    }

    #[test]
    fn test_work_stealing() {
        let scheduler = Arc::new(MorselScheduler::new(2));

        let handle1 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle2 = WorkerHandle::new(Arc::clone(&scheduler));

        for i in 0..5 {
            handle1.push_local(Morsel::new(i, 0, i * 100, (i + 1) * 100));
        }

        // The local queue is FIFO, so the owner takes morsel 0 first...
        let own = handle1.get_work().unwrap();
        assert_eq!(own.id, 0);

        // ...and a thief steals from the front of the remaining queue.
        let stolen = handle2.get_work().expect("handle2 should steal from handle1");
        assert_eq!(stolen.id, 1);
    }

    #[test]
    fn test_done_waits_for_local_children() {
        let scheduler = Arc::new(MorselScheduler::new(2));
        let parent_worker = WorkerHandle::new(Arc::clone(&scheduler));
        let thief = WorkerHandle::new(Arc::clone(&scheduler));

        scheduler.submit(Morsel::new(0, 0, 0, 100));
        scheduler.finish_submission();

        let parent = parent_worker.get_work().unwrap();
        assert_eq!(parent.id, 0);

        // While processing the parent, the worker spawns a child morsel.
        parent_worker.push_local(Morsel::new(1, 0, 100, 200));
        assert_eq!(scheduler.active_count(), 2);

        let child = thief.get_work().unwrap();
        assert_eq!(child.id, 1);
        thief.complete_morsel();
        // The parent is still running, so the scheduler must not be done yet.
        assert!(!scheduler.is_done());

        parent_worker.complete_morsel();
        assert!(scheduler.is_done());
        assert!(parent_worker.get_work().is_none());
    }

    #[test]
    fn test_concurrent_workers() {
        use std::thread;

        let scheduler = Arc::new(MorselScheduler::new(4));
        let total_morsels = 100;

        for i in 0..total_morsels {
            scheduler.submit(Morsel::new(i, 0, i * 100, (i + 1) * 100));
        }
        scheduler.finish_submission();

        let completed = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();

        for _ in 0..4 {
            let sched = Arc::clone(&scheduler);
            let completed = Arc::clone(&completed);

            handles.push(thread::spawn(move || {
                let handle = WorkerHandle::new(sched);
                let mut count = 0;
                while let Some(_morsel) = handle.get_work() {
                    count += 1;
                    handle.complete_morsel();
                }
                completed.fetch_add(count, Ordering::Relaxed);
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(completed.load(Ordering::Relaxed), total_morsels);
        assert!(scheduler.is_done());
    }
}