grafeo_core/execution/parallel/scheduler.rs

//! Morsel scheduler with work-stealing for parallel execution.
//!
//! The scheduler distributes morsels to worker threads using a work-stealing
//! strategy: each worker drains its own local queue, then pulls from the
//! shared global queue, and finally steals from other workers' queues.
//!
//! # NUMA Awareness
//!
//! The scheduler supports NUMA-aware work stealing. Workers are assigned to
//! NUMA nodes and prefer to steal from workers on the same node to minimize
//! cross-node memory access latency.
//!
//! On systems without explicit NUMA support, workers are assigned to virtual
//! nodes based on their ID, approximating locality through ID proximity.
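//!
//! # Example
//!
//! A minimal end-to-end sketch of the intended usage: submit morsels, run a
//! few workers that drain the queues through a [`WorkerHandle`], and mark each
//! morsel complete. The `Morsel::new` arguments mirror this module's tests;
//! the snippet is marked `ignore` because the public paths under which these
//! types are exposed are not assumed here.
//!
//! ```ignore
//! use std::sync::Arc;
//! use std::thread;
//!
//! let scheduler = Arc::new(MorselScheduler::new(4));
//!
//! // Producer side: enqueue work, then signal that nothing more is coming.
//! for i in 0..100 {
//!     scheduler.submit(Morsel::new(i, 0, i * 100, (i + 1) * 100));
//! }
//! scheduler.finish_submission();
//!
//! // Consumer side: each worker pulls morsels until no work is left.
//! let workers: Vec<_> = (0..4)
//!     .map(|_| {
//!         let sched = Arc::clone(&scheduler);
//!         thread::spawn(move || {
//!             let handle = WorkerHandle::new(sched);
//!             while let Some(_morsel) = handle.get_work() {
//!                 // ... process the morsel here ...
//!                 handle.complete_morsel();
//!             }
//!         })
//!     })
//!     .collect();
//!
//! for worker in workers {
//!     worker.join().unwrap();
//! }
//! assert!(scheduler.is_done());
//! ```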

use super::morsel::Morsel;
use crossbeam::deque::{Injector, Steal, Stealer, Worker};
use parking_lot::Mutex;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

/// NUMA node identifier.
pub type NumaNode = usize;

/// Configuration for NUMA-aware scheduling.
#[derive(Debug, Clone)]
pub struct NumaConfig {
    /// Number of NUMA nodes.
    pub num_nodes: usize,
    /// Workers per NUMA node.
    pub workers_per_node: usize,
}

impl Default for NumaConfig {
    fn default() -> Self {
        // Default: assume a uniform memory architecture (1 node).
        // `usize::MAX` is a sentinel meaning "all workers map to node 0";
        // see `worker_node`.
        Self {
            num_nodes: 1,
            workers_per_node: usize::MAX,
        }
    }
}

impl NumaConfig {
    /// Creates a config for a specific NUMA topology.
    #[must_use]
    pub fn with_topology(num_nodes: usize, workers_per_node: usize) -> Self {
        Self {
            num_nodes,
            workers_per_node,
        }
    }

    /// Auto-detects a NUMA topology (an approximation based on worker count).
    ///
    /// Heuristic: assume 2 NUMA nodes for more than 8 workers, 1 otherwise.
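    ///
    /// # Examples
    ///
    /// A small illustration of the heuristic (it approximates a topology from
    /// the worker count rather than probing the hardware):
    ///
    /// ```ignore
    /// // 4 workers: treated as a single node.
    /// assert_eq!(NumaConfig::auto_detect(4).num_nodes, 1);
    ///
    /// // 16 workers: split into 2 virtual nodes of 8 workers each.
    /// let config = NumaConfig::auto_detect(16);
    /// assert_eq!(config.num_nodes, 2);
    /// assert_eq!(config.workers_per_node, 8);
    /// ```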
    #[must_use]
    pub fn auto_detect(num_workers: usize) -> Self {
        if num_workers > 8 {
            // Assume 2 NUMA nodes on larger systems
            Self {
                num_nodes: 2,
                workers_per_node: (num_workers + 1) / 2,
            }
        } else {
            Self::default()
        }
    }

    /// Returns the NUMA node for a worker ID.
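    ///
    /// Workers are mapped to nodes by integer division, so with
    /// `workers_per_node = 4` workers 0..=3 land on node 0 and workers 4..=7
    /// land on node 1. A short sketch:
    ///
    /// ```ignore
    /// let config = NumaConfig::with_topology(2, 4);
    /// assert_eq!(config.worker_node(3), 0);
    /// assert_eq!(config.worker_node(4), 1);
    /// ```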
    #[must_use]
    pub fn worker_node(&self, worker_id: usize) -> NumaNode {
        if self.workers_per_node == usize::MAX {
            0
        } else {
            worker_id / self.workers_per_node
        }
    }
}

/// Work-stealing morsel scheduler.
///
/// Distributes morsels to worker threads efficiently:
/// 1. Workers check the global injector queue
/// 2. If empty, steal from other workers via stealers
///
/// Supports NUMA-aware stealing to minimize cross-node memory access.
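///
/// # Example
///
/// The scheduler can also be driven directly, without a [`WorkerHandle`]:
/// pull from the global queue and acknowledge each morsel. A minimal
/// single-threaded sketch (the `Morsel::new` arguments mirror this module's
/// tests):
///
/// ```ignore
/// let scheduler = MorselScheduler::new(2);
///
/// scheduler.submit(Morsel::new(0, 0, 0, 1000));
/// scheduler.finish_submission();
///
/// while let Some(_morsel) = scheduler.get_global_work() {
///     // ... process the morsel here ...
///     scheduler.complete_morsel();
/// }
/// assert!(scheduler.is_done());
/// ```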
pub struct MorselScheduler {
    /// Number of worker threads.
    num_workers: usize,
    /// Global queue for morsel distribution.
    global_queue: Injector<Morsel>,
    /// Stealers for work-stealing (one per worker).
    stealers: Mutex<Vec<Stealer<Morsel>>>,
    /// Count of morsels submitted but not yet completed.
    active_morsels: AtomicUsize,
    /// Total morsels submitted.
    total_submitted: AtomicUsize,
    /// Whether submission is complete.
    submission_done: AtomicBool,
    /// Whether all work is done.
    done: AtomicBool,
    /// NUMA configuration for locality-aware stealing.
    numa_config: NumaConfig,
}

impl MorselScheduler {
    /// Creates a new scheduler for the given number of workers.
    #[must_use]
    pub fn new(num_workers: usize) -> Self {
        Self::with_numa_config(num_workers, NumaConfig::auto_detect(num_workers))
    }

    /// Creates a scheduler with explicit NUMA configuration.
    #[must_use]
    pub fn with_numa_config(num_workers: usize, numa_config: NumaConfig) -> Self {
        Self {
            num_workers,
            global_queue: Injector::new(),
            stealers: Mutex::new(Vec::with_capacity(num_workers)),
            active_morsels: AtomicUsize::new(0),
            total_submitted: AtomicUsize::new(0),
            submission_done: AtomicBool::new(false),
            done: AtomicBool::new(false),
            numa_config,
        }
    }

    /// Returns the number of workers.
    #[must_use]
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }

    /// Submits a morsel to the global queue.
    pub fn submit(&self, morsel: Morsel) {
        self.global_queue.push(morsel);
        self.active_morsels.fetch_add(1, Ordering::Relaxed);
        self.total_submitted.fetch_add(1, Ordering::Relaxed);
    }

    /// Submits multiple morsels to the global queue.
    pub fn submit_batch(&self, morsels: Vec<Morsel>) {
        let count = morsels.len();
        for morsel in morsels {
            self.global_queue.push(morsel);
        }
        self.active_morsels.fetch_add(count, Ordering::Relaxed);
        self.total_submitted.fetch_add(count, Ordering::Relaxed);
    }

    /// Signals that no more morsels will be submitted.
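    ///
    /// Together with [`complete_morsel`](Self::complete_morsel) this drives the
    /// completion protocol: [`is_done`](Self::is_done) only turns `true` once
    /// submission is finished *and* every submitted morsel has been completed.
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let scheduler = MorselScheduler::new(1);
    /// scheduler.submit(Morsel::new(0, 0, 0, 100));
    /// scheduler.finish_submission();
    /// assert!(!scheduler.is_done()); // one morsel is still outstanding
    ///
    /// let _morsel = scheduler.get_global_work().unwrap();
    /// scheduler.complete_morsel();
    /// assert!(scheduler.is_done());
    /// ```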
    pub fn finish_submission(&self) {
        self.submission_done.store(true, Ordering::Release);
        // Check if all work is already done
        if self.active_morsels.load(Ordering::Acquire) == 0 {
            self.done.store(true, Ordering::Release);
        }
    }

    /// Registers a worker's stealer for work-stealing.
    ///
    /// Returns the assigned worker ID.
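    ///
    /// [`WorkerHandle::new`] calls this automatically; direct registration is
    /// only needed when managing `crossbeam` deques by hand, roughly as in
    /// this sketch:
    ///
    /// ```ignore
    /// use crossbeam::deque::Worker;
    ///
    /// let scheduler = MorselScheduler::new(2);
    /// let local: Worker<Morsel> = Worker::new_fifo();
    /// let worker_id = scheduler.register_worker(local.stealer());
    /// assert_eq!(worker_id, 0); // IDs are handed out in registration order
    /// ```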
    pub fn register_worker(&self, stealer: Stealer<Morsel>) -> usize {
        let mut stealers = self.stealers.lock();
        let worker_id = stealers.len();
        stealers.push(stealer);
        worker_id
    }

    /// Gets work from the global queue.
    pub fn get_global_work(&self) -> Option<Morsel> {
        loop {
            match self.global_queue.steal() {
                Steal::Success(morsel) => return Some(morsel),
                Steal::Empty => return None,
                Steal::Retry => continue,
            }
        }
    }

    /// Tries to steal work from other workers.
    ///
    /// Uses NUMA-aware stealing: prefers workers on the same NUMA node
    /// to minimize cross-node memory access latency.
    pub fn steal_work(&self, my_id: usize) -> Option<Morsel> {
        let stealers = self.stealers.lock();
        let num_stealers = stealers.len();

        if num_stealers <= 1 {
            return None;
        }

        // Get my NUMA node
        let my_node = self.numa_config.worker_node(my_id);

        // Phase 1: Try to steal from workers on the same NUMA node
        for i in 1..num_stealers {
            let victim = (my_id + i) % num_stealers;
            let victim_node = self.numa_config.worker_node(victim);

            // Skip workers on different nodes in first pass
            if victim_node != my_node {
                continue;
            }

            if let Some(morsel) = Self::try_steal_from(&stealers[victim]) {
                return Some(morsel);
            }
        }

        // Phase 2: Try workers on other NUMA nodes (cross-node stealing)
        for i in 1..num_stealers {
            let victim = (my_id + i) % num_stealers;
            let victim_node = self.numa_config.worker_node(victim);

            // Only try workers on different nodes now
            if victim_node == my_node {
                continue;
            }

            if let Some(morsel) = Self::try_steal_from(&stealers[victim]) {
                return Some(morsel);
            }
        }

        None
    }

    /// Attempts to steal from a single stealer.
    fn try_steal_from(stealer: &Stealer<Morsel>) -> Option<Morsel> {
        loop {
            match stealer.steal() {
                Steal::Success(morsel) => return Some(morsel),
                Steal::Empty => return None,
                Steal::Retry => continue,
            }
        }
    }

    /// Returns the NUMA node for a worker.
    #[must_use]
    pub fn worker_node(&self, worker_id: usize) -> NumaNode {
        self.numa_config.worker_node(worker_id)
    }

    /// Marks a morsel as completed.
    ///
    /// Must be called after processing each morsel.
    pub fn complete_morsel(&self) {
        let prev = self.active_morsels.fetch_sub(1, Ordering::Release);
        if prev == 1 && self.submission_done.load(Ordering::Acquire) {
            self.done.store(true, Ordering::Release);
        }
    }

    /// Returns whether all work is done.
    #[must_use]
    pub fn is_done(&self) -> bool {
        self.done.load(Ordering::Acquire)
    }

    /// Returns whether submission is complete.
    #[must_use]
    pub fn is_submission_done(&self) -> bool {
        self.submission_done.load(Ordering::Acquire)
    }

    /// Returns the number of active morsels (submitted but not yet completed).
    #[must_use]
    pub fn active_count(&self) -> usize {
        self.active_morsels.load(Ordering::Relaxed)
    }

    /// Returns the total number of morsels submitted.
    #[must_use]
    pub fn total_submitted(&self) -> usize {
        self.total_submitted.load(Ordering::Relaxed)
    }
}

impl std::fmt::Debug for MorselScheduler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MorselScheduler")
            .field("num_workers", &self.num_workers)
            .field(
                "active_morsels",
                &self.active_morsels.load(Ordering::Relaxed),
            )
            .field(
                "total_submitted",
                &self.total_submitted.load(Ordering::Relaxed),
            )
            .field(
                "submission_done",
                &self.submission_done.load(Ordering::Relaxed),
            )
            .field("done", &self.done.load(Ordering::Relaxed))
            .finish()
    }
}

/// Handle for a worker to interact with the scheduler.
///
/// Provides a simpler API for workers with integrated work-stealing.
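///
/// # Example
///
/// Workers can also queue follow-up work for themselves via
/// [`push_local`](Self::push_local); other workers will steal it if the owner
/// falls behind. A minimal sketch (the `Morsel::new` arguments mirror this
/// module's tests):
///
/// ```ignore
/// use std::sync::Arc;
///
/// let scheduler = Arc::new(MorselScheduler::new(2));
/// let handle = WorkerHandle::new(Arc::clone(&scheduler));
///
/// handle.push_local(Morsel::new(0, 0, 0, 100));
///
/// // The local queue is checked before the global queue and before stealing.
/// let _morsel = handle.get_work().unwrap();
/// handle.complete_morsel();
/// ```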
pub struct WorkerHandle {
    scheduler: Arc<MorselScheduler>,
    worker_id: usize,
    local_queue: Worker<Morsel>,
}

impl WorkerHandle {
    /// Creates a new worker handle and registers with the scheduler.
    #[must_use]
    pub fn new(scheduler: Arc<MorselScheduler>) -> Self {
        let local_queue = Worker::new_fifo();
        let worker_id = scheduler.register_worker(local_queue.stealer());
        Self {
            scheduler,
            worker_id,
            local_queue,
        }
    }

    /// Gets the next morsel to process.
    ///
    /// Tries: local queue -> global queue -> steal from others
    pub fn get_work(&self) -> Option<Morsel> {
        // Try local queue first
        if let Some(morsel) = self.local_queue.pop() {
            return Some(morsel);
        }

        // Try global queue
        if let Some(morsel) = self.scheduler.get_global_work() {
            return Some(morsel);
        }

        // Try stealing from others
        if let Some(morsel) = self.scheduler.steal_work(self.worker_id) {
            return Some(morsel);
        }

        // No work available anywhere right now. Callers should consult
        // `is_done()` to decide whether to retry or exit.
        None
    }

    /// Pushes a morsel to this worker's local queue.
    ///
    /// The morsel is counted as active, so [`complete_morsel`](Self::complete_morsel)
    /// must still be called for it; it is not included in the scheduler's
    /// `total_submitted` count.
    pub fn push_local(&self, morsel: Morsel) {
        self.local_queue.push(morsel);
        self.scheduler
            .active_morsels
            .fetch_add(1, Ordering::Relaxed);
    }

    /// Marks the current morsel as complete.
    pub fn complete_morsel(&self) {
        self.scheduler.complete_morsel();
    }

    /// Returns the worker ID.
    #[must_use]
    pub fn worker_id(&self) -> usize {
        self.worker_id
    }

    /// Returns whether all work is done.
    #[must_use]
    pub fn is_done(&self) -> bool {
        self.scheduler.is_done()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scheduler_creation() {
        let scheduler = MorselScheduler::new(4);
        assert_eq!(scheduler.num_workers(), 4);
        assert_eq!(scheduler.active_count(), 0);
        assert!(!scheduler.is_done());
    }

    #[test]
    fn test_submit_and_get_work() {
        let scheduler = Arc::new(MorselScheduler::new(2));

        scheduler.submit(Morsel::new(0, 0, 0, 1000));
        scheduler.submit(Morsel::new(1, 0, 1000, 2000));
        assert_eq!(scheduler.total_submitted(), 2);
        assert_eq!(scheduler.active_count(), 2);

        // Get work from global queue
        let morsel = scheduler.get_global_work().unwrap();
        assert_eq!(morsel.id, 0);

        // Complete the morsel
        scheduler.complete_morsel();
        assert_eq!(scheduler.active_count(), 1);

        // Get more work
        let morsel = scheduler.get_global_work().unwrap();
        assert_eq!(morsel.id, 1);
        scheduler.complete_morsel();

        scheduler.finish_submission();
        assert!(scheduler.is_done());
    }

    #[test]
    fn test_submit_batch() {
        let scheduler = MorselScheduler::new(4);

        let morsels = vec![
            Morsel::new(0, 0, 0, 100),
            Morsel::new(1, 0, 100, 200),
            Morsel::new(2, 0, 200, 300),
        ];
        scheduler.submit_batch(morsels);

        assert_eq!(scheduler.total_submitted(), 3);
        assert_eq!(scheduler.active_count(), 3);
    }

    #[test]
    fn test_worker_handle() {
        let scheduler = Arc::new(MorselScheduler::new(2));

        let handle = WorkerHandle::new(Arc::clone(&scheduler));
        assert_eq!(handle.worker_id(), 0);
        assert!(!handle.is_done());

        scheduler.submit(Morsel::new(0, 0, 0, 100));

        let morsel = handle.get_work().unwrap();
        assert_eq!(morsel.id, 0);

        handle.complete_morsel();
        scheduler.finish_submission();

        assert!(handle.is_done());
    }

    #[test]
    fn test_worker_local_queue() {
        let scheduler = Arc::new(MorselScheduler::new(2));
        let handle = WorkerHandle::new(Arc::clone(&scheduler));

        // Push to local queue
        handle.push_local(Morsel::new(0, 0, 0, 100));

        // Should get it from local queue
        let morsel = handle.get_work().unwrap();
        assert_eq!(morsel.id, 0);
    }

    #[test]
    fn test_work_stealing() {
        let scheduler = Arc::new(MorselScheduler::new(2));

        // Create two workers
        let handle1 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle2 = WorkerHandle::new(Arc::clone(&scheduler));

        // Push multiple items to worker 1's local queue
        for i in 0..5 {
            handle1.push_local(Morsel::new(i, 0, i * 100, (i + 1) * 100));
        }

        // Worker 1 takes one
        let _ = handle1.get_work().unwrap();

        // Worker 2 should be able to steal
        let stolen = handle2.get_work();
        assert!(stolen.is_some());
    }

    #[test]
    fn test_concurrent_workers() {
        use std::thread;

        let scheduler = Arc::new(MorselScheduler::new(4));
        let total_morsels = 100;

        // Submit morsels
        for i in 0..total_morsels {
            scheduler.submit(Morsel::new(i, 0, i * 100, (i + 1) * 100));
        }
        scheduler.finish_submission();

        // Spawn workers
        let completed = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();

        for _ in 0..4 {
            let sched = Arc::clone(&scheduler);
            let completed = Arc::clone(&completed);

            handles.push(thread::spawn(move || {
                let handle = WorkerHandle::new(sched);
                let mut count = 0;
                while let Some(_morsel) = handle.get_work() {
                    count += 1;
                    handle.complete_morsel();
                }
                completed.fetch_add(count, Ordering::Relaxed);
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(completed.load(Ordering::Relaxed), total_morsels);
    }

    #[test]
    fn test_numa_config_default() {
        let config = NumaConfig::default();
        assert_eq!(config.num_nodes, 1);
        assert_eq!(config.worker_node(0), 0);
        assert_eq!(config.worker_node(100), 0);
    }

    #[test]
    fn test_numa_config_auto_detect() {
        // Small system: 1 NUMA node
        let config = NumaConfig::auto_detect(4);
        assert_eq!(config.num_nodes, 1);

        // Larger system: 2 NUMA nodes
        let config = NumaConfig::auto_detect(16);
        assert_eq!(config.num_nodes, 2);
        assert_eq!(config.workers_per_node, 8);
    }

    #[test]
    fn test_numa_config_worker_node() {
        let config = NumaConfig::with_topology(2, 4);

        // First 4 workers on node 0
        assert_eq!(config.worker_node(0), 0);
        assert_eq!(config.worker_node(1), 0);
        assert_eq!(config.worker_node(2), 0);
        assert_eq!(config.worker_node(3), 0);

        // Next 4 workers on node 1
        assert_eq!(config.worker_node(4), 1);
        assert_eq!(config.worker_node(5), 1);
        assert_eq!(config.worker_node(6), 1);
        assert_eq!(config.worker_node(7), 1);
    }

    #[test]
    fn test_scheduler_with_numa_config() {
        let config = NumaConfig::with_topology(2, 2);
        let scheduler = MorselScheduler::with_numa_config(4, config);

        assert_eq!(scheduler.num_workers(), 4);
        assert_eq!(scheduler.worker_node(0), 0);
        assert_eq!(scheduler.worker_node(1), 0);
        assert_eq!(scheduler.worker_node(2), 1);
        assert_eq!(scheduler.worker_node(3), 1);
    }

    #[test]
    fn test_numa_aware_stealing() {
        // Create scheduler with 2 NUMA nodes, 2 workers each
        let config = NumaConfig::with_topology(2, 2);
        let scheduler = Arc::new(MorselScheduler::with_numa_config(4, config));

        // Create 4 workers (0,1 on node 0; 2,3 on node 1)
        let handle0 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle1 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle2 = WorkerHandle::new(Arc::clone(&scheduler));
        let _handle3 = WorkerHandle::new(Arc::clone(&scheduler));

        // Worker 0 has work
        for i in 0..10 {
            handle0.push_local(Morsel::new(i, 0, i * 100, (i + 1) * 100));
        }

        // Worker 1 (same NUMA node) should be able to steal
        let stolen1 = handle1.get_work();
        assert!(stolen1.is_some(), "Same-node worker should steal first");

        // Worker 2 (different NUMA node) can also steal
        let stolen2 = handle2.get_work();
        assert!(stolen2.is_some(), "Cross-node worker can steal");
    }
}