grafeo_core/execution/parallel/
scheduler.rs1use super::morsel::Morsel;
16use crossbeam::deque::{Injector, Steal, Stealer, Worker};
17use parking_lot::Mutex;
18use std::sync::Arc;
19use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
20
/// Index of a NUMA node. Workers that map to the same node are preferred as
/// steal victims.
pub type NumaNode = usize;

/// Describes how worker ids map onto NUMA nodes.
///
/// Workers are assigned to nodes in contiguous runs: ids
/// `0..workers_per_node` belong to node 0, the next `workers_per_node` ids to
/// node 1, and so on.
#[derive(Debug, Clone)]
pub struct NumaConfig {
    /// Number of NUMA nodes in the topology.
    pub num_nodes: usize,
    /// Number of consecutive worker ids assigned to each node.
    /// `usize::MAX` means "every worker shares node 0".
    pub workers_per_node: usize,
}

impl Default for NumaConfig {
    /// A single-node topology: every worker maps to node 0.
    fn default() -> Self {
        Self {
            num_nodes: 1,
            workers_per_node: usize::MAX,
        }
    }
}

impl NumaConfig {
    /// Creates a topology of `num_nodes` nodes with `workers_per_node` workers each.
    #[must_use]
    pub fn with_topology(num_nodes: usize, workers_per_node: usize) -> Self {
        Self {
            num_nodes,
            workers_per_node,
        }
    }

    /// Guesses a topology from the worker count alone.
    ///
    /// No hardware detection is performed. With more than 8 workers, the
    /// workers are split across two nodes. When the count is odd, the first
    /// node gets the extra worker. Otherwise the single-node
    /// [`Default`] topology is used.
    #[must_use]
    pub fn auto_detect(num_workers: usize) -> Self {
        if num_workers > 8 {
            Self {
                num_nodes: 2,
                // ceil(num_workers / 2), written so it cannot overflow.
                workers_per_node: num_workers / 2 + num_workers % 2,
            }
        } else {
            Self::default()
        }
    }

    /// Returns the NUMA node that `worker_id` belongs to.
    ///
    /// A `workers_per_node` of `usize::MAX` (the default) maps every worker to
    /// node 0. A degenerate `workers_per_node` of `0` also maps to node 0
    /// rather than panicking on division by zero.
    #[must_use]
    pub fn worker_node(&self, worker_id: usize) -> NumaNode {
        if self.workers_per_node == usize::MAX {
            // Special-cased so that `worker_id == usize::MAX` still maps to 0.
            0
        } else {
            worker_id.checked_div(self.workers_per_node).unwrap_or(0)
        }
    }
}
79
80pub struct MorselScheduler {
88 num_workers: usize,
90 global_queue: Injector<Morsel>,
92 stealers: Mutex<Vec<Stealer<Morsel>>>,
94 active_morsels: AtomicUsize,
96 total_submitted: AtomicUsize,
98 submission_done: AtomicBool,
100 done: AtomicBool,
102 numa_config: NumaConfig,
104}
105
106impl MorselScheduler {
107 #[must_use]
109 pub fn new(num_workers: usize) -> Self {
110 Self::with_numa_config(num_workers, NumaConfig::auto_detect(num_workers))
111 }
112
113 #[must_use]
115 pub fn with_numa_config(num_workers: usize, numa_config: NumaConfig) -> Self {
116 Self {
117 num_workers,
118 global_queue: Injector::new(),
119 stealers: Mutex::new(Vec::with_capacity(num_workers)),
120 active_morsels: AtomicUsize::new(0),
121 total_submitted: AtomicUsize::new(0),
122 submission_done: AtomicBool::new(false),
123 done: AtomicBool::new(false),
124 numa_config,
125 }
126 }
127
128 #[must_use]
130 pub fn num_workers(&self) -> usize {
131 self.num_workers
132 }
133
134 pub fn submit(&self, morsel: Morsel) {
136 self.global_queue.push(morsel);
137 self.active_morsels.fetch_add(1, Ordering::Relaxed);
138 self.total_submitted.fetch_add(1, Ordering::Relaxed);
139 }
140
141 pub fn submit_batch(&self, morsels: Vec<Morsel>) {
143 let count = morsels.len();
144 for morsel in morsels {
145 self.global_queue.push(morsel);
146 }
147 self.active_morsels.fetch_add(count, Ordering::Relaxed);
148 self.total_submitted.fetch_add(count, Ordering::Relaxed);
149 }
150
151 pub fn finish_submission(&self) {
153 self.submission_done.store(true, Ordering::Release);
154 if self.active_morsels.load(Ordering::Acquire) == 0 {
156 self.done.store(true, Ordering::Release);
157 }
158 }
159
160 pub fn register_worker(&self, stealer: Stealer<Morsel>) -> usize {
164 let mut stealers = self.stealers.lock();
165 let worker_id = stealers.len();
166 stealers.push(stealer);
167 worker_id
168 }
169
170 pub fn get_global_work(&self) -> Option<Morsel> {
172 loop {
173 match self.global_queue.steal() {
174 Steal::Success(morsel) => return Some(morsel),
175 Steal::Empty => return None,
176 Steal::Retry => continue,
177 }
178 }
179 }
180
181 pub fn steal_work(&self, my_id: usize) -> Option<Morsel> {
186 let stealers = self.stealers.lock();
187 let num_stealers = stealers.len();
188
189 if num_stealers <= 1 {
190 return None;
191 }
192
193 let my_node = self.numa_config.worker_node(my_id);
195
196 for i in 1..num_stealers {
198 let victim = (my_id + i) % num_stealers;
199 let victim_node = self.numa_config.worker_node(victim);
200
201 if victim_node != my_node {
203 continue;
204 }
205
206 if let Some(morsel) = Self::try_steal_from(&stealers[victim]) {
207 return Some(morsel);
208 }
209 }
210
211 for i in 1..num_stealers {
213 let victim = (my_id + i) % num_stealers;
214 let victim_node = self.numa_config.worker_node(victim);
215
216 if victim_node == my_node {
218 continue;
219 }
220
221 if let Some(morsel) = Self::try_steal_from(&stealers[victim]) {
222 return Some(morsel);
223 }
224 }
225
226 None
227 }
228
229 fn try_steal_from(stealer: &Stealer<Morsel>) -> Option<Morsel> {
231 loop {
232 match stealer.steal() {
233 Steal::Success(morsel) => return Some(morsel),
234 Steal::Empty => return None,
235 Steal::Retry => continue,
236 }
237 }
238 }
239
240 #[must_use]
242 pub fn worker_node(&self, worker_id: usize) -> NumaNode {
243 self.numa_config.worker_node(worker_id)
244 }
245
246 pub fn complete_morsel(&self) {
250 let prev = self.active_morsels.fetch_sub(1, Ordering::Release);
251 if prev == 1 && self.submission_done.load(Ordering::Acquire) {
252 self.done.store(true, Ordering::Release);
253 }
254 }
255
256 #[must_use]
258 pub fn is_done(&self) -> bool {
259 self.done.load(Ordering::Acquire)
260 }
261
262 #[must_use]
264 pub fn is_submission_done(&self) -> bool {
265 self.submission_done.load(Ordering::Acquire)
266 }
267
268 #[must_use]
270 pub fn active_count(&self) -> usize {
271 self.active_morsels.load(Ordering::Relaxed)
272 }
273
274 #[must_use]
276 pub fn total_submitted(&self) -> usize {
277 self.total_submitted.load(Ordering::Relaxed)
278 }
279}
280
281impl std::fmt::Debug for MorselScheduler {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 f.debug_struct("MorselScheduler")
284 .field("num_workers", &self.num_workers)
285 .field(
286 "active_morsels",
287 &self.active_morsels.load(Ordering::Relaxed),
288 )
289 .field(
290 "total_submitted",
291 &self.total_submitted.load(Ordering::Relaxed),
292 )
293 .field(
294 "submission_done",
295 &self.submission_done.load(Ordering::Relaxed),
296 )
297 .field("done", &self.done.load(Ordering::Relaxed))
298 .finish()
299 }
300}
301
302pub struct WorkerHandle {
306 scheduler: Arc<MorselScheduler>,
307 worker_id: usize,
308 local_queue: Worker<Morsel>,
309}
310
311impl WorkerHandle {
312 #[must_use]
314 pub fn new(scheduler: Arc<MorselScheduler>) -> Self {
315 let local_queue = Worker::new_fifo();
316 let worker_id = scheduler.register_worker(local_queue.stealer());
317 Self {
318 scheduler,
319 worker_id,
320 local_queue,
321 }
322 }
323
324 pub fn get_work(&self) -> Option<Morsel> {
328 if let Some(morsel) = self.local_queue.pop() {
330 return Some(morsel);
331 }
332
333 if let Some(morsel) = self.scheduler.get_global_work() {
335 return Some(morsel);
336 }
337
338 if let Some(morsel) = self.scheduler.steal_work(self.worker_id) {
340 return Some(morsel);
341 }
342
343 if self.scheduler.is_submission_done() && self.scheduler.active_count() == 0 {
345 return None;
346 }
347
348 None
349 }
350
351 pub fn push_local(&self, morsel: Morsel) {
353 self.local_queue.push(morsel);
354 self.scheduler
355 .active_morsels
356 .fetch_add(1, Ordering::Relaxed);
357 }
358
359 pub fn complete_morsel(&self) {
361 self.scheduler.complete_morsel();
362 }
363
364 #[must_use]
366 pub fn worker_id(&self) -> usize {
367 self.worker_id
368 }
369
370 #[must_use]
372 pub fn is_done(&self) -> bool {
373 self.scheduler.is_done()
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scheduler_creation() {
        let scheduler = MorselScheduler::new(4);
        assert_eq!(scheduler.num_workers(), 4);
        assert_eq!(scheduler.active_count(), 0);
        assert!(!scheduler.is_done());
        assert!(!scheduler.is_submission_done());
    }

    #[test]
    fn test_finish_submission_without_work_is_done() {
        let scheduler = MorselScheduler::new(2);
        scheduler.finish_submission();
        assert!(scheduler.is_submission_done());
        assert!(scheduler.is_done());
    }

    #[test]
    fn test_submit_and_get_work() {
        let scheduler = Arc::new(MorselScheduler::new(2));

        scheduler.submit(Morsel::new(0, 0, 0, 1000));
        scheduler.submit(Morsel::new(1, 0, 1000, 2000));
        assert_eq!(scheduler.total_submitted(), 2);
        assert_eq!(scheduler.active_count(), 2);

        // The global injector is FIFO: morsels come back in submission order.
        let morsel = scheduler.get_global_work().unwrap();
        assert_eq!(morsel.id, 0);

        scheduler.complete_morsel();
        assert_eq!(scheduler.active_count(), 1);

        let morsel = scheduler.get_global_work().unwrap();
        assert_eq!(morsel.id, 1);
        scheduler.complete_morsel();

        // All work is complete, but the scheduler is not done until the
        // producer says submission is over.
        assert!(!scheduler.is_done());
        scheduler.finish_submission();
        assert!(scheduler.is_done());
    }

    #[test]
    fn test_submit_batch() {
        let scheduler = MorselScheduler::new(4);

        let morsels = vec![
            Morsel::new(0, 0, 0, 100),
            Morsel::new(1, 0, 100, 200),
            Morsel::new(2, 0, 200, 300),
        ];
        scheduler.submit_batch(morsels);

        assert_eq!(scheduler.total_submitted(), 3);
        assert_eq!(scheduler.active_count(), 3);
    }

    #[test]
    fn test_worker_handle() {
        let scheduler = Arc::new(MorselScheduler::new(2));

        let handle = WorkerHandle::new(Arc::clone(&scheduler));
        assert_eq!(handle.worker_id(), 0);
        assert!(!handle.is_done());

        scheduler.submit(Morsel::new(0, 0, 0, 100));

        let morsel = handle.get_work().unwrap();
        assert_eq!(morsel.id, 0);

        handle.complete_morsel();
        scheduler.finish_submission();

        assert!(handle.is_done());
    }

    #[test]
    fn test_worker_local_queue() {
        let scheduler = Arc::new(MorselScheduler::new(2));
        let handle = WorkerHandle::new(Arc::clone(&scheduler));

        handle.push_local(Morsel::new(0, 0, 0, 100));
        // Locally pushed morsels are in flight but not "submitted".
        assert_eq!(scheduler.active_count(), 1);
        assert_eq!(scheduler.total_submitted(), 0);

        let morsel = handle.get_work().unwrap();
        assert_eq!(morsel.id, 0);

        handle.complete_morsel();
        assert_eq!(scheduler.active_count(), 0);
    }

    #[test]
    fn test_work_stealing() {
        let scheduler = Arc::new(MorselScheduler::new(2));

        let handle1 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle2 = WorkerHandle::new(Arc::clone(&scheduler));

        for i in 0..5 {
            handle1.push_local(Morsel::new(i, 0, i * 100, (i + 1) * 100));
        }

        // Worker 1 takes one from its own queue...
        let _ = handle1.get_work().unwrap();

        // ...and worker 2, with nothing local or global, steals from worker 1.
        let stolen = handle2.get_work();
        assert!(stolen.is_some());
    }

    #[test]
    fn test_concurrent_workers() {
        use std::thread;

        let scheduler = Arc::new(MorselScheduler::new(4));
        let total_morsels = 100;

        for i in 0..total_morsels {
            scheduler.submit(Morsel::new(i, 0, i * 100, (i + 1) * 100));
        }
        scheduler.finish_submission();

        let completed = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();

        for _ in 0..4 {
            let sched = Arc::clone(&scheduler);
            let completed = Arc::clone(&completed);

            handles.push(thread::spawn(move || {
                let handle = WorkerHandle::new(sched);
                let mut count = 0;
                while let Some(_morsel) = handle.get_work() {
                    count += 1;
                    handle.complete_morsel();
                }
                completed.fetch_add(count, Ordering::Relaxed);
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(completed.load(Ordering::Relaxed), total_morsels);
        assert_eq!(scheduler.active_count(), 0);
        assert!(scheduler.is_done());
    }

    #[test]
    fn test_numa_config_default() {
        let config = NumaConfig::default();
        assert_eq!(config.num_nodes, 1);
        assert_eq!(config.worker_node(0), 0);
        assert_eq!(config.worker_node(100), 0);
    }

    #[test]
    fn test_numa_config_auto_detect() {
        // Small pools stay on a single node.
        let config = NumaConfig::auto_detect(4);
        assert_eq!(config.num_nodes, 1);

        // Larger pools are split evenly across two nodes.
        let config = NumaConfig::auto_detect(16);
        assert_eq!(config.num_nodes, 2);
        assert_eq!(config.workers_per_node, 8);
    }

    #[test]
    fn test_numa_config_worker_node() {
        let config = NumaConfig::with_topology(2, 4);

        // Workers 0..4 live on node 0.
        assert_eq!(config.worker_node(0), 0);
        assert_eq!(config.worker_node(1), 0);
        assert_eq!(config.worker_node(2), 0);
        assert_eq!(config.worker_node(3), 0);

        // Workers 4..8 live on node 1.
        assert_eq!(config.worker_node(4), 1);
        assert_eq!(config.worker_node(5), 1);
        assert_eq!(config.worker_node(6), 1);
        assert_eq!(config.worker_node(7), 1);
    }

    #[test]
    fn test_scheduler_with_numa_config() {
        let config = NumaConfig::with_topology(2, 2);
        let scheduler = MorselScheduler::with_numa_config(4, config);

        assert_eq!(scheduler.num_workers(), 4);
        assert_eq!(scheduler.worker_node(0), 0);
        assert_eq!(scheduler.worker_node(1), 0);
        assert_eq!(scheduler.worker_node(2), 1);
        assert_eq!(scheduler.worker_node(3), 1);
    }

    #[test]
    fn test_numa_aware_stealing() {
        // Workers 0,1 on node 0; workers 2,3 on node 1.
        let config = NumaConfig::with_topology(2, 2);
        let scheduler = Arc::new(MorselScheduler::with_numa_config(4, config));

        let handle0 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle1 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle2 = WorkerHandle::new(Arc::clone(&scheduler));
        let _handle3 = WorkerHandle::new(Arc::clone(&scheduler));

        for i in 0..10 {
            handle0.push_local(Morsel::new(i, 0, i * 100, (i + 1) * 100));
        }

        let stolen1 = handle1.get_work();
        assert!(stolen1.is_some(), "Same-node worker should steal first");

        let stolen2 = handle2.get_work();
        assert!(stolen2.is_some(), "Cross-node worker can steal");
    }

    #[test]
    fn test_numa_prefers_same_node_victim() {
        // Workers 0,1 on node 0; workers 2,3 on node 1.
        let config = NumaConfig::with_topology(2, 2);
        let scheduler = Arc::new(MorselScheduler::with_numa_config(4, config));

        let handle0 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle1 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle2 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle3 = WorkerHandle::new(Arc::clone(&scheduler));

        handle0.push_local(Morsel::new(0, 0, 0, 100));
        handle2.push_local(Morsel::new(1, 0, 100, 200));

        // Worker 3's ring order is 0, 1, 2. The same-node victim (2) must win
        // even though worker 0 comes first in the ring.
        let stolen = handle3.get_work().unwrap();
        assert_eq!(stolen.id, 1);

        // Worker 1's ring order is 2, 3, 0. The same-node victim (0) must win.
        let stolen = handle1.get_work().unwrap();
        assert_eq!(stolen.id, 0);
    }
}
596}