moduvex_runtime/executor/
work_stealing.rs1use std::sync::{Arc, Mutex};
11
12use super::scheduler::{GlobalQueue, LocalQueue};
13
14pub(crate) struct StealableQueue {
22 inner: Mutex<LocalQueue>,
23}
24
25impl StealableQueue {
26 pub(crate) fn new() -> Self {
27 Self {
28 inner: Mutex::new(LocalQueue::new()),
29 }
30 }
31
32 pub(crate) fn local_mut(&self) -> std::sync::MutexGuard<'_, LocalQueue> {
38 self.inner.lock().unwrap()
39 }
40
41 pub(crate) fn steal_from(
47 &self,
48 dest_local: &mut LocalQueue,
49 dest_global: &Arc<GlobalQueue>,
50 ) -> usize {
51 let mut src = self.inner.lock().unwrap();
52 let count = src.len() / 2;
53 if count == 0 {
54 return 0;
55 }
56
57 let mut batch = Vec::with_capacity(count);
58 src.drain_front(&mut batch, count);
59 drop(src); let mut stolen = 0;
62 for header in batch {
63 if let Some(overflow) = dest_local.push(header) {
65 dest_global.push_header(overflow);
66 }
67 stolen += 1;
68 }
69 stolen
70 }
71
72 pub(crate) fn len(&self) -> usize {
74 self.inner.lock().unwrap().len()
75 }
76
77 pub(crate) fn is_empty(&self) -> bool {
79 self.len() == 0
80 }
81}
82
83pub(crate) struct WorkStealingPool {
91 queues: Vec<Arc<StealableQueue>>,
92}
93
94impl WorkStealingPool {
95 pub(crate) fn new() -> Self {
96 Self { queues: Vec::new() }
97 }
98
99 pub(crate) fn add_worker(&mut self, queue: Arc<StealableQueue>) {
101 self.queues.push(queue);
102 }
103
104 pub(crate) fn steal_one(
109 &self,
110 self_idx: usize,
111 dest_local: &mut LocalQueue,
112 dest_global: &Arc<GlobalQueue>,
113 ) -> usize {
114 for (idx, queue) in self.queues.iter().enumerate() {
115 if idx == self_idx {
116 continue;
117 }
118 let n = queue.steal_from(dest_local, dest_global);
119 if n > 0 {
120 return n;
121 }
122 }
123 0
124 }
125
126 pub(crate) fn worker_count(&self) -> usize {
128 self.queues.len()
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::task::{Task, TaskHeader};

    /// Builds a standalone task header. The `Task` and its join handle are
    /// dropped at the end of this function; only the cloned
    /// `Arc<TaskHeader>` is returned.
    fn make_header() -> Arc<TaskHeader> {
        let (task, _jh) = Task::new(async { 0u32 });
        Arc::clone(&task.header)
    }

    /// Pushes `n` fresh headers into `queue` while holding its lock once.
    /// The value returned by `LocalQueue::push` is ignored.
    fn fill(queue: &StealableQueue, n: usize) {
        let mut local = queue.local_mut();
        for _ in 0..n {
            local.push(make_header());
        }
    }

    #[test]
    fn steal_from_empty_returns_zero() {
        let src = StealableQueue::new();
        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        assert_eq!(src.steal_from(&mut dest, &gq), 0);
        assert_eq!(dest.len(), 0, "nothing should land in the destination");
    }

    #[test]
    fn steal_from_takes_half() {
        let src = StealableQueue::new();
        fill(&src, 8);
        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        let stolen = src.steal_from(&mut dest, &gq);
        assert_eq!(stolen, 4, "should steal exactly half of 8");
        assert_eq!(src.len(), 4, "source should retain the other half");
        assert_eq!(dest.len(), 4, "stolen tasks should land in the destination");
    }

    #[test]
    fn pool_steal_skips_self() {
        let q0 = Arc::new(StealableQueue::new());
        let q1 = Arc::new(StealableQueue::new());
        fill(&q1, 4);
        let mut pool = WorkStealingPool::new();
        pool.add_worker(Arc::clone(&q0));
        pool.add_worker(Arc::clone(&q1));

        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        let n = pool.steal_one(0, &mut dest, &gq);
        assert_eq!(n, 2, "should steal half of q1's 4 tasks");
        assert_eq!(q0.len(), 0, "worker 0's own queue untouched");
        assert_eq!(q1.len(), 2, "q1 keeps the other half");
    }

    #[test]
    fn local_mut_exclusive_access() {
        let sq = StealableQueue::new();
        {
            let mut local = sq.local_mut();
            assert!(local.push(make_header()).is_none());
            assert_eq!(local.len(), 1);
        }
        assert_eq!(sq.len(), 1);
    }

    #[test]
    fn steal_from_single_item_queue_returns_zero() {
        let src = StealableQueue::new();
        src.local_mut().push(make_header());
        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        let stolen = src.steal_from(&mut dest, &gq);
        assert_eq!(stolen, 0, "can't steal half of 1 task");
        assert_eq!(src.len(), 1, "the single task stays with its owner");
    }

    #[test]
    fn stealable_queue_is_empty_and_len() {
        let sq = StealableQueue::new();
        assert!(sq.is_empty());
        assert_eq!(sq.len(), 0);
        sq.local_mut().push(make_header());
        assert!(!sq.is_empty());
        assert_eq!(sq.len(), 1);
    }

    #[test]
    fn pool_worker_count() {
        let mut pool = WorkStealingPool::new();
        assert_eq!(pool.worker_count(), 0);
        pool.add_worker(Arc::new(StealableQueue::new()));
        assert_eq!(pool.worker_count(), 1);
        pool.add_worker(Arc::new(StealableQueue::new()));
        assert_eq!(pool.worker_count(), 2);
    }

    #[test]
    fn pool_all_empty_returns_zero() {
        let mut pool = WorkStealingPool::new();
        pool.add_worker(Arc::new(StealableQueue::new()));
        pool.add_worker(Arc::new(StealableQueue::new()));
        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        assert_eq!(pool.steal_one(0, &mut dest, &gq), 0);
    }

    #[test]
    fn steal_many_items_distributes_half() {
        let src = StealableQueue::new();
        fill(&src, 20);
        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        let stolen = src.steal_from(&mut dest, &gq);
        assert_eq!(stolen, 10);
        assert_eq!(src.len(), 10);
        assert_eq!(dest.len(), 10);
    }

    #[test]
    fn pool_steal_only_from_non_empty_worker() {
        let q0 = Arc::new(StealableQueue::new());
        let q1 = Arc::new(StealableQueue::new());
        let q2 = Arc::new(StealableQueue::new());
        fill(&q2, 4);
        let mut pool = WorkStealingPool::new();
        pool.add_worker(Arc::clone(&q0));
        pool.add_worker(Arc::clone(&q1));
        pool.add_worker(Arc::clone(&q2));

        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        let n = pool.steal_one(0, &mut dest, &gq);
        assert_eq!(n, 2, "should steal half of q2's 4 tasks");
        assert_eq!(q0.len(), 0);
        assert_eq!(q1.len(), 0);
        assert_eq!(q2.len(), 2);
    }

    #[test]
    fn steal_from_2_items_steals_1() {
        let src = StealableQueue::new();
        fill(&src, 2);
        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        let stolen = src.steal_from(&mut dest, &gq);
        assert_eq!(stolen, 1);
        assert_eq!(src.len(), 1);
    }

    #[test]
    fn stealable_queue_len_after_pop() {
        let sq = StealableQueue::new();
        fill(&sq, 2);
        assert_eq!(sq.len(), 2);
        sq.local_mut().pop();
        assert_eq!(sq.len(), 1);
        sq.local_mut().pop();
        assert_eq!(sq.len(), 0);
        assert!(sq.is_empty());
    }

    #[test]
    fn pool_new_has_zero_workers() {
        let pool = WorkStealingPool::new();
        assert_eq!(pool.worker_count(), 0);
    }

    #[test]
    fn pool_steal_one_no_workers_returns_zero() {
        let pool = WorkStealingPool::new();
        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        assert_eq!(pool.steal_one(0, &mut dest, &gq), 0);
    }

    #[test]
    fn pool_steal_skips_self_when_self_has_work() {
        let q0 = Arc::new(StealableQueue::new());
        let q1 = Arc::new(StealableQueue::new());
        fill(&q0, 8);
        let mut pool = WorkStealingPool::new();
        pool.add_worker(Arc::clone(&q0));
        pool.add_worker(Arc::clone(&q1));

        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        let n = pool.steal_one(0, &mut dest, &gq);
        assert_eq!(n, 0, "q1 is empty; should not steal from self");
        assert_eq!(q0.len(), 8, "q0 unchanged");
    }

    #[test]
    fn steal_from_6_items_steals_3() {
        let src = StealableQueue::new();
        fill(&src, 6);
        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        let stolen = src.steal_from(&mut dest, &gq);
        assert_eq!(stolen, 3);
        assert_eq!(src.len(), 3);
    }

    #[test]
    fn pool_steal_one_worker_2_non_self_returns_from_second() {
        let q0 = Arc::new(StealableQueue::new());
        let q1 = Arc::new(StealableQueue::new());
        fill(&q1, 4);
        let mut pool = WorkStealingPool::new();
        pool.add_worker(Arc::clone(&q0));
        pool.add_worker(Arc::clone(&q1));

        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        let n = pool.steal_one(0, &mut dest, &gq);
        assert_eq!(n, 2, "should steal 2 from q1 (half of 4)");
        assert_eq!(q1.len(), 2);
    }

    #[test]
    fn stealable_queue_push_16_items() {
        let sq = StealableQueue::new();
        fill(&sq, 16);
        assert_eq!(sq.len(), 16);
        assert!(!sq.is_empty());
    }

    #[test]
    fn pool_3_workers_steal_from_last() {
        let q0 = Arc::new(StealableQueue::new());
        let q1 = Arc::new(StealableQueue::new());
        let q2 = Arc::new(StealableQueue::new());
        fill(&q2, 10);
        let mut pool = WorkStealingPool::new();
        pool.add_worker(Arc::clone(&q0));
        pool.add_worker(Arc::clone(&q1));
        pool.add_worker(Arc::clone(&q2));

        let mut dest = LocalQueue::new();
        let gq = Arc::new(GlobalQueue::new());
        let n = pool.steal_one(0, &mut dest, &gq);
        assert_eq!(n, 5);
        assert_eq!(q2.len(), 5);
    }
}