open_coroutine_queue/work_steal.rs
use crate::rand::{FastRand, RngSeedGenerator};
use crossbeam_deque::{Injector, Steal};
use st3::fifo::Worker;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};

/// Work stealing global queue, shared by multiple threads.
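///
/// # Examples
///
/// A minimal single-threaded sketch (the sizes `4` and `32` are arbitrary):
///
/// ```
/// use open_coroutine_queue::WorkStealQueue;
///
/// let queue = WorkStealQueue::new(4, 32);
/// queue.push("task");
/// assert_eq!(queue.pop(), Some("task"));
/// ```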
#[repr(C)]
#[derive(Debug)]
pub struct WorkStealQueue<T: Debug> {
    shared_queue: Injector<T>,
    /// Number of pending tasks in the queue. This helps prevent unnecessary
    /// locking in the hot path.
    len: AtomicUsize,
    local_queues: VecDeque<Worker<T>>,
    index: AtomicUsize,
    seed_generator: RngSeedGenerator,
}

impl<T: Debug> Drop for WorkStealQueue<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            for local_queue in &self.local_queues {
                assert!(local_queue.pop().is_none(), "local queue not empty");
            }
            assert!(self.pop().is_none(), "global queue not empty");
        }
    }
}

impl<T: Debug> WorkStealQueue<T> {
    /// Get a global `WorkStealQueue` instance.
    ///
    /// Note: the `INSTANCE` static below lives in a generic function, so it is
    /// shared by every monomorphization; only use this with a single concrete `T`.
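    ///
    /// # Examples
    ///
    /// A minimal sketch; the instance is created lazily and intentionally leaked:
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue: &WorkStealQueue<usize> = WorkStealQueue::get_instance();
    /// queue.push(1);
    /// assert_eq!(queue.pop(), Some(1));
    /// ```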
    #[allow(unsafe_code, trivial_casts, box_pointers)]
    pub fn get_instance<'s>() -> &'s WorkStealQueue<T> {
        static INSTANCE: AtomicUsize = AtomicUsize::new(0);
        let mut ret = INSTANCE.load(Ordering::Acquire);
        if ret == 0 {
            let ptr: &'s mut WorkStealQueue<T> = Box::leak(Box::default());
            let new = ptr as *mut WorkStealQueue<T> as usize;
            // Publish exactly once; a thread that loses the race frees its
            // allocation and uses the winner's instance.
            match INSTANCE.compare_exchange(0, new, Ordering::AcqRel, Ordering::Acquire) {
                Ok(_) => ret = new,
                Err(current) => {
                    unsafe { drop(Box::from_raw(new as *mut WorkStealQueue<T>)) };
                    ret = current;
                }
            }
        }
        unsafe { &*(ret as *mut WorkStealQueue<T>) }
    }

    /// Create a new `WorkStealQueue` instance.
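    ///
    /// # Examples
    ///
    /// A minimal sketch with 2 local queues of capacity 64:
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue: WorkStealQueue<usize> = WorkStealQueue::new(2, 64);
    /// assert!(queue.is_empty());
    /// ```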
    #[must_use]
    pub fn new(local_queues_size: usize, local_capacity: usize) -> Self {
        WorkStealQueue {
            shared_queue: Injector::new(),
            len: AtomicUsize::new(0),
            local_queues: (0..local_queues_size)
                .map(|_| Worker::new(local_capacity))
                .collect(),
            index: AtomicUsize::new(0),
            seed_generator: RngSeedGenerator::default(),
        }
    }

    /// Returns `true` if the global queue is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the size of the global queue.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Acquire)
    }

    /// Push an element to the global queue.
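    ///
    /// # Examples
    ///
    /// A minimal sketch of the global FIFO path:
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue = WorkStealQueue::new(1, 32);
    /// queue.push(1);
    /// queue.push(2);
    /// assert_eq!(queue.len(), 2);
    /// assert_eq!(queue.pop(), Some(1));
    /// assert_eq!(queue.pop(), Some(2));
    /// assert_eq!(queue.pop(), None);
    /// ```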
    pub fn push(&self, item: T) {
        self.shared_queue.push(item);
        // Increment the count atomically; a separate load-then-store
        // would lose updates under concurrent pushes.
        self.len.fetch_add(1, Ordering::Release);
    }

    /// Pop an element from the global queue.
    pub fn pop(&self) -> Option<T> {
        // Fast path: if len == 0, there are no values.
        if self.is_empty() {
            return None;
        }
        loop {
            match self.shared_queue.steal() {
                Steal::Success(item) => {
                    // Decrement the count atomically.
                    self.len.fetch_sub(1, Ordering::Release);
                    return Some(item);
                }
                Steal::Retry => continue,
                Steal::Empty => return None,
            }
        }
    }

    /// Get a local queue. This method should be called at most
    /// `local_queues_size` times; later calls wrap around and hand out
    /// queues that are already in use.
    ///
    /// # Panics
    /// Panics if the local queue failed to initialize, which should never happen.
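    ///
    /// # Examples
    ///
    /// A minimal sketch; queues are handed out round-robin:
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue = WorkStealQueue::new(2, 32);
    /// let local = queue.local_queue();
    /// local.push_back(1);
    /// assert_eq!(local.pop_front(), Some(1));
    /// ```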
    pub fn local_queue(&self) -> LocalQueue<'_, T> {
        let mut index = self.index.fetch_add(1, Ordering::Relaxed);
        if index == usize::MAX {
            // `fetch_add` wraps on overflow; reset the counter defensively.
            self.index.store(0, Ordering::Relaxed);
        }
        index %= self.local_queues.len();
        let local = self
            .local_queues
            .get(index)
            .unwrap_or_else(|| panic!("local queue {index} init failed!"));
        LocalQueue::new(self, local, FastRand::new(self.seed_generator.next_seed()))
    }
}

impl<T: Debug> Default for WorkStealQueue<T> {
    fn default() -> Self {
        Self::new(num_cpus::get(), 256)
    }
}

/// The work stealing local queue, exclusive to a single thread.
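///
/// # Examples
///
/// A minimal sketch via the leaked global instance (this assumes `LocalQueue`
/// is re-exported at the crate root like `WorkStealQueue`):
///
/// ```
/// use open_coroutine_queue::LocalQueue;
///
/// let local: LocalQueue<'_, usize> = LocalQueue::default();
/// local.push_back(1);
/// assert_eq!(local.pop_front(), Some(1));
/// ```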
#[repr(C)]
#[derive(Debug)]
pub struct LocalQueue<'l, T: Debug> {
    /// Used to schedule bookkeeping tasks every so often.
    tick: AtomicU32,
    shared: &'l WorkStealQueue<T>,
    stealing: AtomicBool,
    queue: &'l Worker<T>,
    /// Fast random number generator.
    rand: FastRand,
}

impl<T: Debug> Default for LocalQueue<'_, T> {
    fn default() -> Self {
        WorkStealQueue::get_instance().local_queue()
    }
}

impl<T: Debug> Drop for LocalQueue<'_, T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            assert!(self.queue.pop().is_none(), "local queue not empty");
        }
    }
}

impl<'l, T: Debug> LocalQueue<'l, T> {
    fn new(shared: &'l WorkStealQueue<T>, queue: &'l Worker<T>, rand: FastRand) -> Self {
        LocalQueue {
            tick: AtomicU32::new(0),
            shared,
            stealing: AtomicBool::new(false),
            queue,
            rand,
        }
    }

    /// Returns `true` if the local queue is empty.
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }

    /// Returns `true` if the local queue is full.
    ///
    /// # Examples
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue = WorkStealQueue::new(1, 2);
    /// let local = queue.local_queue();
    /// assert!(local.is_empty());
    /// for i in 0..2 {
    ///     local.push_back(i);
    /// }
    /// assert!(local.is_full());
    /// assert_eq!(local.pop_front(), Some(0));
    /// assert_eq!(local.len(), 1);
    /// assert_eq!(local.pop_front(), Some(1));
    /// assert_eq!(local.pop_front(), None);
    /// assert!(local.is_empty());
    /// ```
    pub fn is_full(&self) -> bool {
        self.queue.spare_capacity() == 0
    }

    /// Returns the number of elements in the queue.
    pub fn len(&self) -> usize {
        self.queue.capacity() - self.queue.spare_capacity()
    }

    fn try_lock(&self) -> bool {
        self.stealing
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
    }

    fn release_lock(&self) {
        self.stealing.store(false, Ordering::Release);
    }

    /// If the local queue is full, first move half of it to the global
    /// queue, then push the new item to the global queue as well.
    ///
    /// # Examples
    ///
    /// In the sketch below, `0` and `2` overflow to the global queue,
    /// so they pop after the locally remaining `1` and `3`:
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue = WorkStealQueue::new(1, 2);
    /// let local = queue.local_queue();
    /// for i in 0..4 {
    ///     local.push_back(i);
    /// }
    /// assert_eq!(local.pop_front(), Some(1));
    /// assert_eq!(local.pop_front(), Some(3));
    /// assert_eq!(local.pop_front(), Some(0));
    /// assert_eq!(local.pop_front(), Some(2));
    /// assert_eq!(local.pop_front(), None);
    /// ```
    pub fn push_back(&self, item: T) {
        if let Err(item) = self.queue.push(item) {
            // Move half of the local queue to the global queue.
            let count = self.len() / 2;
            for _ in 0..count {
                if let Some(item) = self.queue.pop() {
                    self.shared.push(item);
                }
            }
            // Push the rejected item directly to the global queue.
            self.shared.push(item);
        }
    }

    /// Increment the tick.
    fn tick(&self) -> u32 {
        let val = self.tick.fetch_add(1, Ordering::Release);
        if val == u32::MAX {
            self.tick.store(0, Ordering::Release);
            return 0;
        }
        val + 1
    }

    /// Pop an element from the local queue. If the local queue is empty,
    /// try to steal a batch from a sibling local queue, falling back to
    /// the global queue. Every 61st call pops from the global queue first
    /// so that it cannot be starved.
    ///
    /// # Examples
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue = WorkStealQueue::new(1, 32);
    /// for i in 0..4 {
    ///     queue.push(i);
    /// }
    /// let local = queue.local_queue();
    /// for i in 0..4 {
    ///     assert_eq!(local.pop_front(), Some(i));
    /// }
    /// assert_eq!(local.pop_front(), None);
    /// assert_eq!(queue.pop(), None);
    /// ```
    ///
    /// Stealing from a sibling local queue:
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue = WorkStealQueue::new(2, 64);
    /// let local0 = queue.local_queue();
    /// local0.push_back(2);
    /// local0.push_back(3);
    /// local0.push_back(4);
    /// local0.push_back(5);
    /// let local1 = queue.local_queue();
    /// local1.push_back(0);
    /// local1.push_back(1);
    /// for i in 0..6 {
    ///     assert_eq!(local1.pop_front(), Some(i));
    /// }
    /// assert_eq!(local0.pop_front(), None);
    /// assert_eq!(local1.pop_front(), None);
    /// assert_eq!(queue.pop(), None);
    /// ```
    #[allow(clippy::cast_possible_truncation)]
    pub fn pop_front(&self) -> Option<T> {
        // Every 61 local pops, pop from the global queue first,
        // so the global queue cannot be starved.
        if self.tick() % 61 == 0 {
            if let Some(val) = self.shared.pop() {
                return Some(val);
            }
        }

        // Pop an element from the local queue.
        if let Some(val) = self.queue.pop() {
            return Some(val);
        }
        if self.try_lock() {
            // Try to steal from the sibling local queues.
            let local_queues = &self.shared.local_queues;
            let num = local_queues.len();
            let start = self.rand.fastrand_n(num as u32) as usize;
            for i in 0..num {
                let i = (start + i) % num;
                if let Some(another) = local_queues.get(i) {
                    if std::ptr::eq(another, self.queue) {
                        // Never steal from self. Note: compare the `Worker`
                        // addresses themselves, not the addresses of the
                        // references holding them.
                        continue;
                    }
                    if another.is_empty() {
                        // The sibling queue is empty.
                        continue;
                    }
                    if self.queue.spare_capacity() == 0 {
                        // The local queue is full.
                        continue;
                    }
                    if another
                        .stealer()
                        .steal(self.queue, |n| {
                            // Cap the batch at the local queue's spare capacity,
                            n.min(self.queue.spare_capacity())
                                // and at half of the sibling queue's current length.
                                .min(((another.capacity() - another.spare_capacity()) + 1) / 2)
                        })
                        .is_ok()
                    {
                        self.release_lock();
                        return self.queue.pop();
                    }
                }
            }
            self.release_lock();
        }
        // Nothing to steal; fall back to the global queue.
        self.shared.pop()
    }
}
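
#[cfg(test)]
mod tests {
    use super::WorkStealQueue;

    // A minimal sanity-check sketch (the test name and scenario are
    // illustrative): items that overflow `push_back` into the global
    // queue must still be drained exactly once.
    #[test]
    fn push_back_overflow_reaches_global() {
        let queue = WorkStealQueue::new(1, 2);
        let local = queue.local_queue();
        for i in 0..8 {
            local.push_back(i);
        }
        // Drain through the local queue; `pop_front` falls back to the
        // global queue once the local queue and siblings are empty.
        let mut drained = Vec::new();
        while let Some(v) = local.pop_front() {
            drained.push(v);
        }
        drained.sort_unstable();
        assert_eq!(drained, (0..8).collect::<Vec<_>>());
    }
}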