Skip to main content

grafeo_core/execution/parallel/
scheduler.rs

//! Morsel scheduler with work-stealing for parallel execution.
//!
//! The scheduler distributes morsels to worker threads using a work-stealing
//! strategy: each worker drains its own local queue first, then tries the
//! global queue, and finally steals from other workers.
5
6use super::morsel::Morsel;
7use crossbeam::deque::{Injector, Steal, Stealer, Worker};
8use parking_lot::Mutex;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11
12/// Work-stealing morsel scheduler.
13///
14/// Distributes morsels to worker threads efficiently:
15/// 1. Workers check the global injector queue
16/// 2. If empty, steal from other workers via stealers
17pub struct MorselScheduler {
18    /// Number of worker threads.
19    num_workers: usize,
20    /// Global queue for morsel distribution.
21    global_queue: Injector<Morsel>,
22    /// Stealers for work-stealing (one per worker).
23    stealers: Mutex<Vec<Stealer<Morsel>>>,
24    /// Count of morsels still being processed.
25    active_morsels: AtomicUsize,
26    /// Total morsels submitted.
27    total_submitted: AtomicUsize,
28    /// Whether submission is complete.
29    submission_done: AtomicBool,
30    /// Whether all work is done.
31    done: AtomicBool,
32}
33
34impl MorselScheduler {
35    /// Creates a new scheduler for the given number of workers.
36    #[must_use]
37    pub fn new(num_workers: usize) -> Self {
38        Self {
39            num_workers,
40            global_queue: Injector::new(),
41            stealers: Mutex::new(Vec::with_capacity(num_workers)),
42            active_morsels: AtomicUsize::new(0),
43            total_submitted: AtomicUsize::new(0),
44            submission_done: AtomicBool::new(false),
45            done: AtomicBool::new(false),
46        }
47    }
48
49    /// Returns the number of workers.
50    #[must_use]
51    pub fn num_workers(&self) -> usize {
52        self.num_workers
53    }
54
55    /// Submits a morsel to the global queue.
56    pub fn submit(&self, morsel: Morsel) {
57        self.global_queue.push(morsel);
58        self.active_morsels.fetch_add(1, Ordering::Relaxed);
59        self.total_submitted.fetch_add(1, Ordering::Relaxed);
60    }
61
62    /// Submits multiple morsels to the global queue.
63    pub fn submit_batch(&self, morsels: Vec<Morsel>) {
64        let count = morsels.len();
65        for morsel in morsels {
66            self.global_queue.push(morsel);
67        }
68        self.active_morsels.fetch_add(count, Ordering::Relaxed);
69        self.total_submitted.fetch_add(count, Ordering::Relaxed);
70    }
71
72    /// Signals that no more morsels will be submitted.
73    pub fn finish_submission(&self) {
74        self.submission_done.store(true, Ordering::Release);
75        // Check if all work is already done
76        if self.active_morsels.load(Ordering::Acquire) == 0 {
77            self.done.store(true, Ordering::Release);
78        }
79    }
80
81    /// Registers a worker's stealer for work-stealing.
82    ///
83    /// Returns the worker_id assigned.
84    pub fn register_worker(&self, stealer: Stealer<Morsel>) -> usize {
85        let mut stealers = self.stealers.lock();
86        let worker_id = stealers.len();
87        stealers.push(stealer);
88        worker_id
89    }
90
91    /// Gets work from the global queue.
92    pub fn get_global_work(&self) -> Option<Morsel> {
93        loop {
94            match self.global_queue.steal() {
95                Steal::Success(morsel) => return Some(morsel),
96                Steal::Empty => return None,
97                Steal::Retry => continue,
98            }
99        }
100    }
101
102    /// Tries to steal work from other workers.
103    pub fn steal_work(&self, my_id: usize) -> Option<Morsel> {
104        let stealers = self.stealers.lock();
105        let num_stealers = stealers.len();
106
107        if num_stealers <= 1 {
108            return None;
109        }
110
111        // Try each stealer in round-robin fashion starting from the next one
112        for i in 1..num_stealers {
113            let victim = (my_id + i) % num_stealers;
114            loop {
115                match stealers[victim].steal() {
116                    Steal::Success(morsel) => return Some(morsel),
117                    Steal::Empty => break,
118                    Steal::Retry => continue,
119                }
120            }
121        }
122        None
123    }
124
125    /// Marks a morsel as completed.
126    ///
127    /// Must be called after processing each morsel.
128    pub fn complete_morsel(&self) {
129        let prev = self.active_morsels.fetch_sub(1, Ordering::Release);
130        if prev == 1 && self.submission_done.load(Ordering::Acquire) {
131            self.done.store(true, Ordering::Release);
132        }
133    }
134
135    /// Returns whether all work is done.
136    #[must_use]
137    pub fn is_done(&self) -> bool {
138        self.done.load(Ordering::Acquire)
139    }
140
141    /// Returns whether submission is complete.
142    #[must_use]
143    pub fn is_submission_done(&self) -> bool {
144        self.submission_done.load(Ordering::Acquire)
145    }
146
147    /// Returns the number of active (in-progress) morsels.
148    #[must_use]
149    pub fn active_count(&self) -> usize {
150        self.active_morsels.load(Ordering::Relaxed)
151    }
152
153    /// Returns the total number of morsels submitted.
154    #[must_use]
155    pub fn total_submitted(&self) -> usize {
156        self.total_submitted.load(Ordering::Relaxed)
157    }
158}
159
160impl std::fmt::Debug for MorselScheduler {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        f.debug_struct("MorselScheduler")
163            .field("num_workers", &self.num_workers)
164            .field(
165                "active_morsels",
166                &self.active_morsels.load(Ordering::Relaxed),
167            )
168            .field(
169                "total_submitted",
170                &self.total_submitted.load(Ordering::Relaxed),
171            )
172            .field(
173                "submission_done",
174                &self.submission_done.load(Ordering::Relaxed),
175            )
176            .field("done", &self.done.load(Ordering::Relaxed))
177            .finish()
178    }
179}
180
181/// Handle for a worker to interact with the scheduler.
182///
183/// Provides a simpler API for workers with integrated work-stealing.
184pub struct WorkerHandle {
185    scheduler: Arc<MorselScheduler>,
186    worker_id: usize,
187    local_queue: Worker<Morsel>,
188}
189
190impl WorkerHandle {
191    /// Creates a new worker handle and registers with the scheduler.
192    #[must_use]
193    pub fn new(scheduler: Arc<MorselScheduler>) -> Self {
194        let local_queue = Worker::new_fifo();
195        let worker_id = scheduler.register_worker(local_queue.stealer());
196        Self {
197            scheduler,
198            worker_id,
199            local_queue,
200        }
201    }
202
203    /// Gets the next morsel to process.
204    ///
205    /// Tries: local queue -> global queue -> steal from others
206    pub fn get_work(&self) -> Option<Morsel> {
207        // Try local queue first
208        if let Some(morsel) = self.local_queue.pop() {
209            return Some(morsel);
210        }
211
212        // Try global queue
213        if let Some(morsel) = self.scheduler.get_global_work() {
214            return Some(morsel);
215        }
216
217        // Try stealing from others
218        if let Some(morsel) = self.scheduler.steal_work(self.worker_id) {
219            return Some(morsel);
220        }
221
222        // Check if we're done
223        if self.scheduler.is_submission_done() && self.scheduler.active_count() == 0 {
224            return None;
225        }
226
227        None
228    }
229
230    /// Pushes a morsel to this worker's local queue.
231    pub fn push_local(&self, morsel: Morsel) {
232        self.local_queue.push(morsel);
233        self.scheduler
234            .active_morsels
235            .fetch_add(1, Ordering::Relaxed);
236    }
237
238    /// Marks the current morsel as complete.
239    pub fn complete_morsel(&self) {
240        self.scheduler.complete_morsel();
241    }
242
243    /// Returns the worker ID.
244    #[must_use]
245    pub fn worker_id(&self) -> usize {
246        self.worker_id
247    }
248
249    /// Returns whether all work is done.
250    #[must_use]
251    pub fn is_done(&self) -> bool {
252        self.scheduler.is_done()
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh scheduler reports its worker count, no active work, and is not done.
    #[test]
    fn test_scheduler_creation() {
        let sched = MorselScheduler::new(4);
        assert_eq!(sched.num_workers(), 4);
        assert_eq!(sched.active_count(), 0);
        assert!(!sched.is_done());
    }

    /// Morsels come back from the global queue in submission order and the
    /// scheduler is done once everything is completed and submission finished.
    #[test]
    fn test_submit_and_get_work() {
        let sched = Arc::new(MorselScheduler::new(2));

        sched.submit(Morsel::new(0, 0, 0, 1000));
        sched.submit(Morsel::new(1, 0, 1000, 2000));
        assert_eq!(sched.total_submitted(), 2);
        assert_eq!(sched.active_count(), 2);

        // First morsel out of the global queue.
        let first = sched.get_global_work().unwrap();
        assert_eq!(first.id, 0);

        // Completing it leaves one morsel active.
        sched.complete_morsel();
        assert_eq!(sched.active_count(), 1);

        // Second morsel out of the global queue.
        let second = sched.get_global_work().unwrap();
        assert_eq!(second.id, 1);
        sched.complete_morsel();

        sched.finish_submission();
        assert!(sched.is_done());
    }

    /// A batch submission updates both counters by the batch size.
    #[test]
    fn test_submit_batch() {
        let sched = MorselScheduler::new(4);

        sched.submit_batch(vec![
            Morsel::new(0, 0, 0, 100),
            Morsel::new(1, 0, 100, 200),
            Morsel::new(2, 0, 200, 300),
        ]);

        assert_eq!(sched.total_submitted(), 3);
        assert_eq!(sched.active_count(), 3);
    }

    /// A worker handle gets id 0, pulls submitted work, and observes completion.
    #[test]
    fn test_worker_handle() {
        let sched = Arc::new(MorselScheduler::new(2));

        let worker = WorkerHandle::new(Arc::clone(&sched));
        assert_eq!(worker.worker_id(), 0);
        assert!(!worker.is_done());

        sched.submit(Morsel::new(0, 0, 0, 100));

        let got = worker.get_work().unwrap();
        assert_eq!(got.id, 0);

        worker.complete_morsel();
        sched.finish_submission();

        assert!(worker.is_done());
    }

    /// A morsel pushed to the local queue is returned by `get_work`.
    #[test]
    fn test_worker_local_queue() {
        let sched = Arc::new(MorselScheduler::new(2));
        let worker = WorkerHandle::new(Arc::clone(&sched));

        // Push to the local queue.
        worker.push_local(Morsel::new(0, 0, 0, 100));

        // It should come back from the local queue.
        let got = worker.get_work().unwrap();
        assert_eq!(got.id, 0);
    }

    /// A second worker can steal from the first worker's local queue.
    #[test]
    fn test_work_stealing() {
        let sched = Arc::new(MorselScheduler::new(2));

        // Two workers sharing one scheduler.
        let owner = WorkerHandle::new(Arc::clone(&sched));
        let thief = WorkerHandle::new(Arc::clone(&sched));

        // Fill the first worker's local queue.
        for i in 0..5 {
            owner.push_local(Morsel::new(i, 0, i * 100, (i + 1) * 100));
        }

        // The owner takes one morsel itself.
        let _ = owner.get_work().unwrap();

        // The other worker should be able to steal one of the rest.
        let stolen = thief.get_work();
        assert!(stolen.is_some());
    }

    /// Four worker threads together process every submitted morsel exactly once.
    #[test]
    fn test_concurrent_workers() {
        use std::thread;

        let sched = Arc::new(MorselScheduler::new(4));
        let total_morsels = 100;

        // Submit all morsels up front, then close submission.
        for i in 0..total_morsels {
            sched.submit(Morsel::new(i, 0, i * 100, (i + 1) * 100));
        }
        sched.finish_submission();

        // Spawn the workers; each adds its processed count to `completed`.
        let completed = Arc::new(AtomicUsize::new(0));
        let workers: Vec<_> = (0..4)
            .map(|_| {
                let sched = Arc::clone(&sched);
                let completed = Arc::clone(&completed);

                thread::spawn(move || {
                    let handle = WorkerHandle::new(sched);
                    let mut processed = 0;
                    loop {
                        match handle.get_work() {
                            Some(_morsel) => {
                                processed += 1;
                                handle.complete_morsel();
                            }
                            None => break,
                        }
                    }
                    completed.fetch_add(processed, Ordering::Relaxed);
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(completed.load(Ordering::Relaxed), total_morsels);
    }
}