Skip to main content

moduvex_runtime/executor/
work_stealing.rs

1//! Opt-in work-stealing layer.
2//!
3//! `StealableQueue` wraps a `LocalQueue` and exposes a `steal_from` method
4//! that lets another worker grab half of its tasks. This module is intentionally
5//! minimal — the Chase-Lev deque optimisation is deferred to a later phase.
6//!
7//! For the current single-threaded executor the stealing path is never hot, so
8//! correctness and clarity take priority over lock-free performance.
9
10use std::sync::{Arc, Mutex};
11
12use super::scheduler::{GlobalQueue, LocalQueue};
13
14// ── StealableQueue ────────────────────────────────────────────────────────────
15
/// A `LocalQueue` guarded by a `Mutex` so other workers can steal from it.
///
/// The owning worker accesses its queue through `local_mut()`, which acquires
/// the same mutex that stealers use, so access is NOT lock-free — every path
/// goes through the `Mutex`. Contention should be rare because stealing is the
/// infrequent path; a genuinely lock-free Chase-Lev deque is deferred (see the
/// module docs).
pub(crate) struct StealableQueue {
    inner: Mutex<LocalQueue>,
}
24
25impl StealableQueue {
26    pub(crate) fn new() -> Self {
27        Self {
28            inner: Mutex::new(LocalQueue::new()),
29        }
30    }
31
32    /// Exclusive mutable access for the owning worker.
33    ///
34    /// # Panics
35    /// Panics if the mutex is poisoned (i.e. a previous worker thread panicked
36    /// while holding the lock). This is a non-recoverable programming error.
37    pub(crate) fn local_mut(&self) -> std::sync::MutexGuard<'_, LocalQueue> {
38        self.inner.lock().unwrap()
39    }
40
41    /// Steal up to half of the tasks in this queue into `dest_local`.
42    ///
43    /// Returns the number of tasks actually stolen. Returns 0 if the queue is
44    /// empty or if the destination local queue overflows (unlikely given the
45    /// 256-slot capacity).
46    pub(crate) fn steal_from(
47        &self,
48        dest_local: &mut LocalQueue,
49        dest_global: &Arc<GlobalQueue>,
50    ) -> usize {
51        let mut src = self.inner.lock().unwrap();
52        let count = src.len() / 2;
53        if count == 0 {
54            return 0;
55        }
56
57        let mut batch = Vec::with_capacity(count);
58        src.drain_front(&mut batch, count);
59        drop(src); // release lock before pushing to dest
60
61        let mut stolen = 0;
62        for header in batch {
63            // Try local first; spill overflow to global.
64            if let Some(overflow) = dest_local.push(header) {
65                dest_global.push_header(overflow);
66            }
67            stolen += 1;
68        }
69        stolen
70    }
71
72    /// Number of tasks currently in the queue (acquires the lock briefly).
73    pub(crate) fn len(&self) -> usize {
74        self.inner.lock().unwrap().len()
75    }
76
77    /// `true` if the queue is empty (acquires the lock briefly).
78    pub(crate) fn is_empty(&self) -> bool {
79        self.len() == 0
80    }
81}
82
83// ── WorkStealingPool ──────────────────────────────────────────────────────────
84
/// Registry of per-worker `StealableQueue`s that enables cross-worker stealing.
///
/// In the current single-threaded executor this pool has exactly one entry.
/// Multi-worker support (spawning N threads each with their own worker) will
/// populate this pool with N entries and use random victim selection.
pub(crate) struct WorkStealingPool {
    // One shared handle per registered worker; a worker's index in this Vec
    // is its identity for `steal_one`'s self-exclusion check.
    queues: Vec<Arc<StealableQueue>>,
}
93
94impl WorkStealingPool {
95    pub(crate) fn new() -> Self {
96        Self { queues: Vec::new() }
97    }
98
99    /// Register a worker's queue with the pool.
100    pub(crate) fn add_worker(&mut self, queue: Arc<StealableQueue>) {
101        self.queues.push(queue);
102    }
103
104    /// Attempt to steal from any worker other than `self_idx`.
105    ///
106    /// Uses a simple linear scan (no randomisation needed for single-worker).
107    /// Returns the number of tasks stolen, or 0 if all queues were empty.
108    pub(crate) fn steal_one(
109        &self,
110        self_idx: usize,
111        dest_local: &mut LocalQueue,
112        dest_global: &Arc<GlobalQueue>,
113    ) -> usize {
114        for (idx, queue) in self.queues.iter().enumerate() {
115            if idx == self_idx {
116                continue;
117            }
118            let n = queue.steal_from(dest_local, dest_global);
119            if n > 0 {
120                return n;
121            }
122        }
123        0
124    }
125
126    /// Number of registered workers.
127    pub(crate) fn worker_count(&self) -> usize {
128        self.queues.len()
129    }
130}
131
132// ── Tests ─────────────────────────────────────────────────────────────────────
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::task::{Task, TaskHeader};

    /// Builds a fresh task header for populating queues in tests.
    fn make_header() -> Arc<TaskHeader> {
        let (task, _jh) = Task::new(async { 0u32 });
        Arc::clone(&task.header)
    }

    #[test]
    fn steal_from_empty_returns_zero() {
        let victim = StealableQueue::new();
        let mut thief_local = LocalQueue::new();
        let thief_global = Arc::new(GlobalQueue::new());
        assert_eq!(victim.steal_from(&mut thief_local, &thief_global), 0);
    }

    #[test]
    fn steal_from_takes_half() {
        let victim = StealableQueue::new();
        {
            let mut guard = victim.local_mut();
            for _ in 0..8 {
                guard.push(make_header());
            }
        }

        let mut thief_local = LocalQueue::new();
        let thief_global = Arc::new(GlobalQueue::new());
        let taken = victim.steal_from(&mut thief_local, &thief_global);

        assert_eq!(taken, 4, "should steal exactly half of 8");
        assert_eq!(victim.len(), 4, "source should retain the other half");
    }

    #[test]
    fn pool_steal_skips_self() {
        let own_queue = Arc::new(StealableQueue::new());
        let victim_queue = Arc::new(StealableQueue::new());
        {
            let mut guard = victim_queue.local_mut();
            for _ in 0..4 {
                guard.push(make_header());
            }
        }

        let mut pool = WorkStealingPool::new();
        pool.add_worker(Arc::clone(&own_queue));
        pool.add_worker(Arc::clone(&victim_queue));

        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        // Worker 0 tries to steal; skips itself (idx=0), steals from q1 (idx=1).
        let n = pool.steal_one(0, &mut dest, &gq);
        assert!(n >= 1, "should steal from q1");
        assert_eq!(own_queue.len(), 0, "worker 0's own queue untouched");
    }

    #[test]
    fn local_mut_exclusive_access() {
        let sq = StealableQueue::new();
        {
            let mut guard = sq.local_mut();
            assert!(guard.push(make_header()).is_none());
            assert_eq!(guard.len(), 1);
        }
        // Length is visible through the locked accessor after the guard drops.
        assert_eq!(sq.len(), 1);
    }
}
200}