// moduvex_runtime/executor/work_stealing.rs

//! Opt-in work-stealing layer.
//!
//! `StealableQueue` wraps a `LocalQueue` and exposes a `steal_from` method
//! that lets another worker take half of its tasks (rounded down, so a queue
//! holding fewer than two tasks yields nothing). This module is intentionally
//! minimal — the Chase-Lev deque optimisation is deferred to a later phase.
//!
//! For the current single-threaded executor the stealing path is never hot, so
//! correctness and clarity take priority over lock-free performance.
9
10use std::sync::{Arc, Mutex};
11
12use super::scheduler::{GlobalQueue, LocalQueue};
13
14// ── StealableQueue ────────────────────────────────────────────────────────────
15
/// A `LocalQueue` guarded by a `Mutex` so other workers can steal from it.
///
/// Every access path goes through the mutex: the owning worker obtains a
/// guard via `local_mut()`, stealers lock inside `steal_from()`, and
/// `len()` / `is_empty()` lock briefly. None of these paths is lock-free.
/// Stealing is intended to be the infrequent path.
pub(crate) struct StealableQueue {
    /// The wrapped queue; the mutex serialises the owner and any stealers.
    inner: Mutex<LocalQueue>,
}
24
25impl StealableQueue {
26    pub(crate) fn new() -> Self {
27        Self {
28            inner: Mutex::new(LocalQueue::new()),
29        }
30    }
31
32    /// Exclusive mutable access for the owning worker.
33    ///
34    /// # Panics
35    /// Panics if the mutex is poisoned (i.e. a previous worker thread panicked
36    /// while holding the lock). This is a non-recoverable programming error.
37    pub(crate) fn local_mut(&self) -> std::sync::MutexGuard<'_, LocalQueue> {
38        self.inner.lock().unwrap()
39    }
40
41    /// Steal up to half of the tasks in this queue into `dest_local`.
42    ///
43    /// Returns the number of tasks actually stolen. Returns 0 if the queue is
44    /// empty or if the destination local queue overflows (unlikely given the
45    /// 256-slot capacity).
46    pub(crate) fn steal_from(
47        &self,
48        dest_local: &mut LocalQueue,
49        dest_global: &Arc<GlobalQueue>,
50    ) -> usize {
51        let mut src = self.inner.lock().unwrap();
52        let count = src.len() / 2;
53        if count == 0 {
54            return 0;
55        }
56
57        let mut batch = Vec::with_capacity(count);
58        src.drain_front(&mut batch, count);
59        drop(src); // release lock before pushing to dest
60
61        let mut stolen = 0;
62        for header in batch {
63            // Try local first; spill overflow to global.
64            if let Some(overflow) = dest_local.push(header) {
65                dest_global.push_header(overflow);
66            }
67            stolen += 1;
68        }
69        stolen
70    }
71
72    /// Number of tasks currently in the queue (acquires the lock briefly).
73    pub(crate) fn len(&self) -> usize {
74        self.inner.lock().unwrap().len()
75    }
76
77    /// `true` if the queue is empty (acquires the lock briefly).
78    pub(crate) fn is_empty(&self) -> bool {
79        self.len() == 0
80    }
81}
82
83// ── WorkStealingPool ──────────────────────────────────────────────────────────
84
/// Registry of per-worker `StealableQueue`s that enables cross-worker stealing.
///
/// In the current single-threaded executor this pool has exactly one entry.
/// Multi-worker support (spawning N threads each with their own worker) will
/// populate this pool with N entries and use random victim selection.
pub(crate) struct WorkStealingPool {
    /// Registered queues in registration order. A queue's position here is
    /// the worker index that `steal_one` compares against `self_idx`.
    queues: Vec<Arc<StealableQueue>>,
}
93
94impl WorkStealingPool {
95    pub(crate) fn new() -> Self {
96        Self { queues: Vec::new() }
97    }
98
99    /// Register a worker's queue with the pool.
100    pub(crate) fn add_worker(&mut self, queue: Arc<StealableQueue>) {
101        self.queues.push(queue);
102    }
103
104    /// Attempt to steal from any worker other than `self_idx`.
105    ///
106    /// Uses a simple linear scan (no randomisation needed for single-worker).
107    /// Returns the number of tasks stolen, or 0 if all queues were empty.
108    pub(crate) fn steal_one(
109        &self,
110        self_idx: usize,
111        dest_local: &mut LocalQueue,
112        dest_global: &Arc<GlobalQueue>,
113    ) -> usize {
114        for (idx, queue) in self.queues.iter().enumerate() {
115            if idx == self_idx {
116                continue;
117            }
118            let n = queue.steal_from(dest_local, dest_global);
119            if n > 0 {
120                return n;
121            }
122        }
123        0
124    }
125
126    /// Number of registered workers.
127    pub(crate) fn worker_count(&self) -> usize {
128        self.queues.len()
129    }
130}
131
132// ── Tests ─────────────────────────────────────────────────────────────────────
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::task::{Task, TaskHeader};

    // ── Fixtures ──────────────────────────────────────────────────────────

    /// Builds a throwaway task and returns a handle to its header.
    fn make_header() -> Arc<TaskHeader> {
        let (task, _jh) = Task::new(async { 0u32 });
        Arc::clone(&task.header)
    }

    /// Returns a queue pre-loaded with `n` tasks, checking none overflowed.
    fn filled_queue(n: usize) -> StealableQueue {
        let queue = StealableQueue::new();
        {
            let mut local = queue.local_mut();
            for _ in 0..n {
                assert!(
                    local.push(make_header()).is_none(),
                    "fixture queue must not overflow"
                );
            }
        }
        queue
    }

    /// Fresh, empty steal destination: a local queue plus a global queue.
    fn destination() -> (LocalQueue, Arc<GlobalQueue>) {
        (LocalQueue::new(), Arc::new(GlobalQueue::new()))
    }

    /// Builds a pool whose worker indices follow the order of `queues`.
    fn pool_of(queues: &[&Arc<StealableQueue>]) -> WorkStealingPool {
        let mut pool = WorkStealingPool::new();
        for &queue in queues {
            pool.add_worker(Arc::clone(queue));
        }
        pool
    }

    /// Steals once from a queue of `total` tasks and checks the return value,
    /// what the source keeps, and what actually arrives in the destination.
    fn assert_steal_split(total: usize, expected_stolen: usize) {
        let src = filled_queue(total);
        let (mut dest, gq) = destination();
        let stolen = src.steal_from(&mut dest, &gq);
        assert_eq!(stolen, expected_stolen, "stolen count for {total} tasks");
        assert_eq!(
            src.len(),
            total - expected_stolen,
            "source should retain the remainder"
        );
        assert_eq!(
            dest.len(),
            expected_stolen,
            "stolen tasks should land in the destination"
        );
    }

    // ── StealableQueue ────────────────────────────────────────────────────

    #[test]
    fn steal_from_empty_returns_zero() {
        assert_steal_split(0, 0);
    }

    #[test]
    fn steal_from_takes_half() {
        // Should steal exactly half of 8; the source retains the other half.
        assert_steal_split(8, 4);
    }

    #[test]
    fn steal_from_single_item_queue_returns_zero() {
        // Half of 1 = 0 (integer division) — can't steal half of 1 task.
        assert_steal_split(1, 0);
    }

    #[test]
    fn steal_from_2_items_steals_1() {
        assert_steal_split(2, 1);
    }

    #[test]
    fn steal_from_6_items_steals_3() {
        assert_steal_split(6, 3);
    }

    #[test]
    fn steal_from_odd_count_rounds_down() {
        assert_steal_split(5, 2);
    }

    #[test]
    fn steal_many_items_distributes_half() {
        assert_steal_split(20, 10);
    }

    #[test]
    fn local_mut_exclusive_access() {
        let sq = StealableQueue::new();
        {
            let mut local = sq.local_mut();
            assert!(local.push(make_header()).is_none());
            assert_eq!(local.len(), 1);
        }
        // The guard is gone, so `len()` can take the lock again.
        assert_eq!(sq.len(), 1);
    }

    #[test]
    fn stealable_queue_is_empty_and_len() {
        let sq = StealableQueue::new();
        assert!(sq.is_empty());
        assert_eq!(sq.len(), 0);
        sq.local_mut().push(make_header());
        assert!(!sq.is_empty());
        assert_eq!(sq.len(), 1);
    }

    #[test]
    fn stealable_queue_len_after_pop() {
        let sq = filled_queue(2);
        assert_eq!(sq.len(), 2);
        sq.local_mut().pop();
        assert_eq!(sq.len(), 1);
        sq.local_mut().pop();
        assert_eq!(sq.len(), 0);
        assert!(sq.is_empty());
    }

    #[test]
    fn stealable_queue_push_16_items() {
        let sq = filled_queue(16);
        assert_eq!(sq.len(), 16);
        assert!(!sq.is_empty());
    }

    // ── WorkStealingPool ──────────────────────────────────────────────────

    #[test]
    fn pool_new_has_zero_workers() {
        let pool = WorkStealingPool::new();
        assert_eq!(pool.worker_count(), 0);
    }

    #[test]
    fn pool_worker_count() {
        let mut pool = WorkStealingPool::new();
        assert_eq!(pool.worker_count(), 0);
        pool.add_worker(Arc::new(StealableQueue::new()));
        assert_eq!(pool.worker_count(), 1);
        pool.add_worker(Arc::new(StealableQueue::new()));
        assert_eq!(pool.worker_count(), 2);
    }

    #[test]
    fn pool_steal_one_no_workers_returns_zero() {
        let pool = WorkStealingPool::new();
        let (mut dest, gq) = destination();
        // No workers at all.
        assert_eq!(pool.steal_one(0, &mut dest, &gq), 0);
        assert_eq!(dest.len(), 0);
    }

    #[test]
    fn pool_all_empty_returns_zero() {
        let q0 = Arc::new(StealableQueue::new());
        let q1 = Arc::new(StealableQueue::new());
        let pool = pool_of(&[&q0, &q1]);
        let (mut dest, gq) = destination();
        assert_eq!(pool.steal_one(0, &mut dest, &gq), 0);
        assert_eq!(dest.len(), 0);
    }

    #[test]
    fn pool_steal_skips_self() {
        let q0 = Arc::new(StealableQueue::new());
        let q1 = Arc::new(filled_queue(4));
        let pool = pool_of(&[&q0, &q1]);
        let (mut dest, gq) = destination();
        // Worker 0 tries to steal; skips itself (idx=0), steals from q1 (idx=1).
        let n = pool.steal_one(0, &mut dest, &gq);
        assert_eq!(n, 2, "should steal half of q1's 4 tasks");
        assert_eq!(q0.len(), 0, "worker 0's own queue untouched");
        assert_eq!(q1.len(), 2, "q1 keeps the other half");
        assert_eq!(dest.len(), 2, "stolen tasks land in the destination");
    }

    #[test]
    fn pool_steal_skips_self_when_self_has_work() {
        let q0 = Arc::new(filled_queue(8)); // has work — but is self
        let q1 = Arc::new(StealableQueue::new()); // empty
        let pool = pool_of(&[&q0, &q1]);
        let (mut dest, gq) = destination();
        // Worker 0 tries to steal but only q1 (empty) is eligible.
        let n = pool.steal_one(0, &mut dest, &gq);
        assert_eq!(n, 0, "q1 is empty; should not steal from self");
        assert_eq!(q0.len(), 8, "q0 unchanged");
        assert_eq!(dest.len(), 0, "nothing should reach the destination");
    }

    #[test]
    fn pool_steal_only_from_non_empty_worker() {
        let q0 = Arc::new(StealableQueue::new()); // empty
        let q1 = Arc::new(StealableQueue::new()); // empty
        let q2 = Arc::new(filled_queue(4)); // has tasks
        let pool = pool_of(&[&q0, &q1, &q2]);
        let (mut dest, gq) = destination();
        // Worker 0 steals; q1 is empty so it steals from q2.
        let n = pool.steal_one(0, &mut dest, &gq);
        assert_eq!(n, 2, "should steal half of q2's 4 tasks");
        assert_eq!(q0.len(), 0);
        assert_eq!(q1.len(), 0);
        assert_eq!(q2.len(), 2);
        assert_eq!(dest.len(), 2);
    }

    #[test]
    fn pool_steal_one_worker_2_non_self_returns_from_second() {
        let q0 = Arc::new(StealableQueue::new());
        let q1 = Arc::new(filled_queue(4)); // give q1 some work
        let pool = pool_of(&[&q0, &q1]);
        let (mut dest, gq) = destination();
        let n = pool.steal_one(0, &mut dest, &gq);
        assert_eq!(n, 2, "should steal 2 from q1 (half of 4)");
        assert_eq!(q1.len(), 2);
    }

    #[test]
    fn pool_3_workers_steal_from_last() {
        let q0 = Arc::new(StealableQueue::new());
        let q1 = Arc::new(StealableQueue::new());
        let q2 = Arc::new(filled_queue(10));
        let pool = pool_of(&[&q0, &q1, &q2]);
        let (mut dest, gq) = destination();
        let n = pool.steal_one(0, &mut dest, &gq);
        // q0=empty (skip self), q1=empty, q2=10 items → steal 5.
        assert_eq!(n, 5);
        assert_eq!(q2.len(), 5);
        assert_eq!(dest.len(), 5);
    }
}