moduvex_runtime/executor/
work_stealing.rs1use std::sync::{Arc, Mutex};
11
12use super::scheduler::{GlobalQueue, LocalQueue};
13
14pub(crate) struct StealableQueue {
22 inner: Mutex<LocalQueue>,
23}
24
25impl StealableQueue {
26 pub(crate) fn new() -> Self {
27 Self {
28 inner: Mutex::new(LocalQueue::new()),
29 }
30 }
31
32 pub(crate) fn local_mut(&self) -> std::sync::MutexGuard<'_, LocalQueue> {
38 self.inner.lock().unwrap()
39 }
40
41 pub(crate) fn steal_from(
47 &self,
48 dest_local: &mut LocalQueue,
49 dest_global: &Arc<GlobalQueue>,
50 ) -> usize {
51 let mut src = self.inner.lock().unwrap();
52 let count = src.len() / 2;
53 if count == 0 {
54 return 0;
55 }
56
57 let mut batch = Vec::with_capacity(count);
58 src.drain_front(&mut batch, count);
59 drop(src); let mut stolen = 0;
62 for header in batch {
63 if let Some(overflow) = dest_local.push(header) {
65 dest_global.push_header(overflow);
66 }
67 stolen += 1;
68 }
69 stolen
70 }
71
72 pub(crate) fn len(&self) -> usize {
74 self.inner.lock().unwrap().len()
75 }
76
77 pub(crate) fn is_empty(&self) -> bool {
79 self.len() == 0
80 }
81}
82
83pub(crate) struct WorkStealingPool {
91 queues: Vec<Arc<StealableQueue>>,
92}
93
94impl WorkStealingPool {
95 pub(crate) fn new() -> Self {
96 Self { queues: Vec::new() }
97 }
98
99 pub(crate) fn add_worker(&mut self, queue: Arc<StealableQueue>) {
101 self.queues.push(queue);
102 }
103
104 pub(crate) fn steal_one(
109 &self,
110 self_idx: usize,
111 dest_local: &mut LocalQueue,
112 dest_global: &Arc<GlobalQueue>,
113 ) -> usize {
114 for (idx, queue) in self.queues.iter().enumerate() {
115 if idx == self_idx {
116 continue;
117 }
118 let n = queue.steal_from(dest_local, dest_global);
119 if n > 0 {
120 return n;
121 }
122 }
123 0
124 }
125
126 pub(crate) fn worker_count(&self) -> usize {
128 self.queues.len()
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::task::{Task, TaskHeader};

    /// Builds a fresh task header backed by a trivial `async { 0u32 }` future.
    fn make_header() -> Arc<TaskHeader> {
        let (task, _join_handle) = Task::new(async { 0u32 });
        Arc::clone(&task.header)
    }

    /// Pushes `count` fresh headers into `queue` through its owner-side guard.
    /// The guard, and therefore the lock, is released on return.
    fn fill(queue: &StealableQueue, count: usize) {
        let mut guard = queue.local_mut();
        for _ in 0..count {
            guard.push(make_header());
        }
    }

    #[test]
    fn steal_from_empty_returns_zero() {
        let victim = StealableQueue::new();
        let mut dest = LocalQueue::new();
        let global = Arc::new(GlobalQueue::new());
        assert_eq!(victim.steal_from(&mut dest, &global), 0);
    }

    #[test]
    fn steal_from_takes_half() {
        let victim = StealableQueue::new();
        fill(&victim, 8);

        let mut dest = LocalQueue::new();
        let global = Arc::new(GlobalQueue::new());
        let taken = victim.steal_from(&mut dest, &global);

        assert_eq!(taken, 4, "should steal exactly half of 8");
        assert_eq!(victim.len(), 4, "source should retain the other half");
    }

    #[test]
    fn pool_steal_skips_self() {
        let own = Arc::new(StealableQueue::new());
        let victim = Arc::new(StealableQueue::new());
        fill(&victim, 4);

        let mut pool = WorkStealingPool::new();
        pool.add_worker(Arc::clone(&own));
        pool.add_worker(Arc::clone(&victim));

        let mut dest = LocalQueue::new();
        let global = Arc::new(GlobalQueue::new());
        let taken = pool.steal_one(0, &mut dest, &global);

        assert!(taken >= 1, "should steal from q1");
        assert_eq!(own.len(), 0, "worker 0's own queue untouched");
    }

    #[test]
    fn local_mut_exclusive_access() {
        let queue = StealableQueue::new();
        let mut guard = queue.local_mut();
        assert!(guard.push(make_header()).is_none());
        assert_eq!(guard.len(), 1);
        drop(guard);
        assert_eq!(queue.len(), 1);
    }
}