open_coroutine_queue/
work_steal.rs

use crate::rand::{FastRand, RngSeedGenerator};
use crossbeam_deque::{Injector, Steal};
use st3::fifo::Worker;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};

/// Work stealing global queue, shared by multiple threads.
#[repr(C)]
#[derive(Debug)]
pub struct WorkStealQueue<T: Debug> {
    shared_queue: Injector<T>,
    /// Number of pending tasks in the queue. This helps prevent unnecessary
    /// locking in the hot path.
    len: AtomicUsize,
    local_queues: VecDeque<Worker<T>>,
    index: AtomicUsize,
    seed_generator: RngSeedGenerator,
}

impl<T: Debug> Drop for WorkStealQueue<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            for local_queue in &self.local_queues {
                assert!(local_queue.pop().is_none(), "local queue not empty");
            }
            assert!(self.pop().is_none(), "global queue not empty");
        }
    }
}

impl<T: Debug> WorkStealQueue<T> {
    /// Get a global `WorkStealQueue` instance.
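    ///
    /// # Examples
    ///
    /// A minimal single-threaded usage sketch (`usize` is just an illustrative element type):
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let instance = WorkStealQueue::<usize>::get_instance();
    /// instance.push(1);
    /// assert_eq!(instance.pop(), Some(1));
    /// assert!(instance.is_empty());
    /// ```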
    #[allow(unsafe_code, trivial_casts, box_pointers)]
    pub fn get_instance<'s>() -> &'s WorkStealQueue<T> {
        static INSTANCE: AtomicUsize = AtomicUsize::new(0);
        let mut ret = INSTANCE.load(Ordering::Acquire);
        if ret == 0 {
            // Lazily create the singleton. If another thread won the race,
            // free our candidate and use the winner's instance instead of
            // leaking a second queue.
            let ptr = Box::into_raw(Box::<WorkStealQueue<T>>::default());
            match INSTANCE.compare_exchange(0, ptr as usize, Ordering::AcqRel, Ordering::Acquire) {
                Ok(_) => ret = ptr as usize,
                Err(current) => {
                    drop(unsafe { Box::from_raw(ptr) });
                    ret = current;
                }
            }
        }
        unsafe { &*(ret as *const WorkStealQueue<T>) }
    }

    /// Create a new `WorkStealQueue` instance.
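    ///
    /// # Examples
    ///
    /// A minimal sketch; the sizes below are illustrative:
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue: WorkStealQueue<usize> = WorkStealQueue::new(2, 64);
    /// assert!(queue.is_empty());
    /// assert_eq!(queue.len(), 0);
    /// ```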
    #[must_use]
    pub fn new(local_queues_size: usize, local_capacity: usize) -> Self {
        WorkStealQueue {
            shared_queue: Injector::new(),
            len: AtomicUsize::new(0),
            local_queues: (0..local_queues_size)
                .map(|_| Worker::new(local_capacity))
                .collect(),
            index: AtomicUsize::new(0),
            seed_generator: RngSeedGenerator::default(),
        }
    }

    /// Returns `true` if the global queue is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the size of the global queue.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Acquire)
    }

    /// Push an element to the global queue.
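    ///
    /// # Examples
    ///
    /// A minimal sketch; the `Injector` hands elements back in FIFO order:
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue = WorkStealQueue::new(1, 32);
    /// queue.push(1);
    /// queue.push(2);
    /// assert_eq!(queue.len(), 2);
    /// assert_eq!(queue.pop(), Some(1));
    /// assert_eq!(queue.pop(), Some(2));
    /// ```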
    pub fn push(&self, item: T) {
        self.shared_queue.push(item);
        // Increment the count atomically; a plain load-then-store would race
        // with concurrent push/pop calls.
        self.len.fetch_add(1, Ordering::Release);
    }

    /// Pop an element from the global queue.
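    ///
    /// # Examples
    ///
    /// A small sketch of the empty fast path and a successful pop:
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue: WorkStealQueue<i32> = WorkStealQueue::new(1, 32);
    /// // Empty queue: the fast path returns `None` without touching the injector.
    /// assert_eq!(queue.pop(), None);
    /// queue.push(42);
    /// assert_eq!(queue.pop(), Some(42));
    /// ```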
    pub fn pop(&self) -> Option<T> {
        // Fast path: if len == 0, there are no values.
        if self.is_empty() {
            return None;
        }
        loop {
            match self.shared_queue.steal() {
                Steal::Success(item) => {
                    // Decrement the count atomically.
                    self.len.fetch_sub(1, Ordering::Release);
                    return Some(item);
                }
                Steal::Retry => continue,
                Steal::Empty => return None,
            }
        }
    }

    /// Get a local queue. This method should be called at most `local_queues_size` times;
    /// calling it more often hands out an already-used local queue again.
    ///
    /// # Panics
    /// Should never happen.
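    ///
    /// # Examples
    ///
    /// A minimal sketch of obtaining and using a local queue:
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue = WorkStealQueue::new(2, 32);
    /// let local = queue.local_queue();
    /// local.push_back(1);
    /// assert_eq!(local.pop_front(), Some(1));
    /// assert_eq!(local.pop_front(), None);
    /// ```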
    pub fn local_queue(&self) -> LocalQueue<'_, T> {
        let mut index = self.index.fetch_add(1, Ordering::Relaxed);
        if index == usize::MAX {
            self.index.store(0, Ordering::Relaxed);
        }
        index %= self.local_queues.len();
        let local = self
            .local_queues
            .get(index)
            .unwrap_or_else(|| panic!("local queue {index} init failed!"));
        LocalQueue::new(self, local, FastRand::new(self.seed_generator.next_seed()))
    }
}

impl<T: Debug> Default for WorkStealQueue<T> {
    fn default() -> Self {
        Self::new(num_cpus::get(), 256)
    }
}

/// The work stealing local queue, exclusive to a single thread.
#[repr(C)]
#[derive(Debug)]
pub struct LocalQueue<'l, T: Debug> {
    /// Used to schedule bookkeeping tasks every so often.
    tick: AtomicU32,
    shared: &'l WorkStealQueue<T>,
    stealing: AtomicBool,
    queue: &'l Worker<T>,
    /// Fast random number generator.
    rand: FastRand,
}

impl<T: Debug> Default for LocalQueue<'_, T> {
    fn default() -> Self {
        WorkStealQueue::get_instance().local_queue()
    }
}

impl<T: Debug> Drop for LocalQueue<'_, T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            assert!(self.queue.pop().is_none(), "local queue not empty");
        }
    }
}

impl<'l, T: Debug> LocalQueue<'l, T> {
    fn new(shared: &'l WorkStealQueue<T>, queue: &'l Worker<T>, rand: FastRand) -> Self {
        LocalQueue {
            tick: AtomicU32::new(0),
            shared,
            stealing: AtomicBool::new(false),
            queue,
            rand,
        }
    }

    /// Returns `true` if the local queue is empty.
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }

    /// Returns `true` if the local queue is full.
    ///
    /// # Examples
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue = WorkStealQueue::new(1, 2);
    /// let local = queue.local_queue();
    /// assert!(local.is_empty());
    /// for i in 0..2 {
    ///     local.push_back(i);
    /// }
    /// assert!(local.is_full());
    /// assert_eq!(local.pop_front(), Some(0));
    /// assert_eq!(local.len(), 1);
    /// assert_eq!(local.pop_front(), Some(1));
    /// assert_eq!(local.pop_front(), None);
    /// assert!(local.is_empty());
    /// ```
    pub fn is_full(&self) -> bool {
        self.queue.spare_capacity() == 0
    }

    /// Returns the number of elements in the queue.
    pub fn len(&self) -> usize {
        self.queue.capacity() - self.queue.spare_capacity()
    }

    fn try_lock(&self) -> bool {
        self.stealing
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
    }

    fn release_lock(&self) {
        self.stealing.store(false, Ordering::Release);
    }

    /// If the local queue is full, first move half of it to the global queue,
    /// then push the item to the global queue as well.
    ///
    /// # Examples
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue = WorkStealQueue::new(1, 2);
    /// let local = queue.local_queue();
    /// for i in 0..4 {
    ///     local.push_back(i);
    /// }
    /// assert_eq!(local.pop_front(), Some(1));
    /// assert_eq!(local.pop_front(), Some(3));
    /// assert_eq!(local.pop_front(), Some(0));
    /// assert_eq!(local.pop_front(), Some(2));
    /// assert_eq!(local.pop_front(), None);
    /// ```
    pub fn push_back(&self, item: T) {
        if let Err(item) = self.queue.push(item) {
            // The local queue is full: move half of it to the global queue.
            let count = self.len() / 2;
            for _ in 0..count {
                if let Some(item) = self.queue.pop() {
                    self.shared.push(item);
                }
            }
            // Then push the rejected item directly to the global queue.
            self.shared.push(item);
        }
    }

    /// Increment the tick.
    fn tick(&self) -> u32 {
        let val = self.tick.fetch_add(1, Ordering::Release);
        if val == u32::MAX {
            self.tick.store(0, Ordering::Release);
            return 0;
        }
        val + 1
    }

    /// If the local queue is empty, try to steal a batch from a sibling local
    /// queue, and finally fall back to the global queue. Every 61 ticks the
    /// global queue is polled first so that it does not starve.
    ///
    /// # Examples
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue = WorkStealQueue::new(1, 32);
    /// for i in 0..4 {
    ///     queue.push(i);
    /// }
    /// let local = queue.local_queue();
    /// for i in 0..4 {
    ///     assert_eq!(local.pop_front(), Some(i));
    /// }
    /// assert_eq!(local.pop_front(), None);
    /// assert_eq!(queue.pop(), None);
    /// ```
    ///
    /// ```
    /// use open_coroutine_queue::WorkStealQueue;
    ///
    /// let queue = WorkStealQueue::new(2, 64);
    /// let local0 = queue.local_queue();
    /// local0.push_back(2);
    /// local0.push_back(3);
    /// local0.push_back(4);
    /// local0.push_back(5);
    /// let local1 = queue.local_queue();
    /// local1.push_back(0);
    /// local1.push_back(1);
    /// for i in 0..6 {
    ///     assert_eq!(local1.pop_front(), Some(i));
    /// }
    /// assert_eq!(local0.pop_front(), None);
    /// assert_eq!(local1.pop_front(), None);
    /// assert_eq!(queue.pop(), None);
    /// ```
    #[allow(clippy::cast_possible_truncation)]
    pub fn pop_front(&self) -> Option<T> {
        // Every 61 local pops, poll the global queue first to keep it from starving.
        if self.tick() % 61 == 0 {
            if let Some(val) = self.shared.pop() {
                return Some(val);
            }
        }

        // Pop an element from the local queue.
        if let Some(val) = self.queue.pop() {
            return Some(val);
        }
        if self.try_lock() {
            // Try to steal from the sibling local queues.
            let local_queues = &self.shared.local_queues;
            let num = local_queues.len();
            let start = self.rand.fastrand_n(num as u32) as usize;
            for i in 0..num {
                let i = (start + i) % num;
                if let Some(another) = local_queues.get(i) {
                    if std::ptr::eq(another, self.queue) {
                        // Never steal from ourselves.
                        continue;
                    }
                    if another.is_empty() {
                        // The sibling queue is empty.
                        continue;
                    }
                    if self.queue.spare_capacity() == 0 {
                        // The local queue is full.
                        continue;
                    }
                    if another
                        .stealer()
                        .steal(self.queue, |n| {
                            // Steal at most as many items as we have spare capacity for,
                            // and at most half of the sibling queue's current length.
                            n.min(self.queue.spare_capacity())
                                .min(((another.capacity() - another.spare_capacity()) + 1) / 2)
                        })
                        .is_ok()
                    {
                        self.release_lock();
                        return self.queue.pop();
                    }
                }
            }
            self.release_lock();
        }
        // Nothing to steal anywhere: fall back to the global queue.
        self.shared.pop()
    }
}