base_coroutine/work_steal.rs

use crate::random::Rng;
use concurrent_queue::{ConcurrentQueue, PushError};
use once_cell::sync::{Lazy, OnceCell};
use st3::fifo::Worker;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::io::ErrorKind;
use std::os::raw::c_void;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

static mut INSTANCE: Lazy<Queue> = Lazy::new(Queue::default);

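/// Hand the calling thread a reference to one of the local work-steal queues,
/// assigned round-robin from the global `Queue` instance.
///
/// A minimal usage sketch (hypothetical values, not taken from the test suite):
///
/// ```ignore
/// let queue = get_queue();
/// queue.push_back_raw(1 as *mut c_void).unwrap();
/// assert_eq!(queue.pop_front_raw(), Some(1 as *mut c_void));
/// ```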
pub fn get_queue() -> &'static mut WorkStealQueue {
    unsafe { INSTANCE.local_queue() }
}

static mut GLOBAL_LOCK: Lazy<AtomicBool> = Lazy::new(|| AtomicBool::new(false));

pub(crate) static mut GLOBAL_QUEUE: Lazy<ConcurrentQueue<*mut c_void>> =
    Lazy::new(ConcurrentQueue::unbounded);

pub(crate) static mut LOCAL_QUEUES: OnceCell<Box<[WorkStealQueue]>> = OnceCell::new();

#[repr(C)]
#[derive(Debug)]
struct Queue {
    index: AtomicUsize,
}

impl Queue {
    fn new(local_queues: usize, local_capacity: usize) -> Self {
        unsafe {
            LOCAL_QUEUES.get_or_init(|| {
                (0..local_queues)
                    .map(|_| WorkStealQueue::new(local_capacity))
                    .collect()
            });
        }
        Queue {
            index: AtomicUsize::new(0),
        }
    }

    /// Push an item onto the global queue so that a local queue can pick it
    /// up once it runs empty.
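    ///
    /// A rough sketch of the intended flow (values are illustrative only and
    /// assume the local queue is otherwise empty):
    ///
    /// ```ignore
    /// unsafe { INSTANCE.push_raw(1 as *mut c_void) };
    /// // an empty local queue steals the item back out of the global queue
    /// assert_eq!(get_queue().pop_front_raw(), Some(1 as *mut c_void));
    /// ```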
    fn push<T>(&self, item: T) {
        let ptr = Box::leak(Box::new(item));
        self.push_raw(ptr as *mut _ as *mut c_void)
    }

    fn push_raw(&self, ptr: *mut c_void) {
        unsafe { GLOBAL_QUEUE.push(ptr).unwrap() }
    }

    fn local_queue(&mut self) -> &mut WorkStealQueue {
        let index = self.index.fetch_add(1, Ordering::Relaxed);
        if index == usize::MAX {
            self.index.store(0, Ordering::Relaxed);
        }
        unsafe {
            // distribute callers round-robin over however many local queues exist
            let local_queues = LOCAL_QUEUES.get_mut().unwrap();
            let len = local_queues.len();
            local_queues.get_mut(index % len).unwrap()
        }
    }
}

impl Default for Queue {
    fn default() -> Self {
        Self::new(num_cpus::get(), 256)
    }
}

/// Error type returned by the steal methods.
#[derive(Debug)]
pub enum StealError {
    CanNotStealSelf,
    EmptySibling,
    NoMoreSpare,
    StealSiblingFailed,
}

impl Display for StealError {
    fn fmt(&self, fmt: &mut Formatter) -> std::fmt::Result {
        match *self {
            StealError::CanNotStealSelf => write!(fmt, "can not steal from self"),
            StealError::EmptySibling => write!(fmt, "the sibling queue is empty"),
            StealError::NoMoreSpare => write!(fmt, "self has no spare capacity left"),
            StealError::StealSiblingFailed => write!(fmt, "stealing from another local queue failed"),
        }
    }
}

impl Error for StealError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        None
    }
}

#[repr(C)]
#[derive(Debug)]
pub struct WorkStealQueue {
    stealing: AtomicBool,
    queue: Worker<*mut c_void>,
}

impl WorkStealQueue {
    fn new(max_capacity: usize) -> Self {
        WorkStealQueue {
            stealing: AtomicBool::new(false),
            queue: Worker::new(max_capacity),
        }
    }

    pub fn push_back<T>(&mut self, element: T) -> std::io::Result<()> {
        let ptr = Box::leak(Box::new(element));
        self.push_back_raw(ptr as *mut _ as *mut c_void)
    }

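    /// Push a raw pointer onto the local queue. If the local ring is full,
    /// roughly half of the local items are spilled into the global queue and
    /// the rejected item is pushed there as well.
    ///
    /// A sketch of that overflow path, assuming a tiny queue built with the
    /// private constructor (illustrative only):
    ///
    /// ```ignore
    /// let mut q = WorkStealQueue::new(2);
    /// q.push_back_raw(1 as *mut c_void).unwrap();
    /// q.push_back_raw(2 as *mut c_void).unwrap();
    /// // the local ring is full, so this push spills into GLOBAL_QUEUE
    /// q.push_back_raw(3 as *mut c_void).unwrap();
    /// ```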
    pub fn push_back_raw(&mut self, ptr: *mut c_void) -> std::io::Result<()> {
        if let Err(item) = self.queue.push(ptr) {
            unsafe {
                // move half of the local queue into the global queue
                let count = self.len() / 2;
                // TODO: one copy could actually be avoided here
                let half = Worker::new(count);
                let stealer = self.queue.stealer();
                let _ = stealer.steal(&half, |_n| count);
                while !half.is_empty() {
                    let _ = GLOBAL_QUEUE.push(half.pop().unwrap());
                }
                GLOBAL_QUEUE.push(item).map_err(|e| match e {
                    PushError::Full(_) => {
                        std::io::Error::new(ErrorKind::Other, "global queue is full")
                    }
                    PushError::Closed(_) => {
                        std::io::Error::new(ErrorKind::Other, "global queue closed")
                    }
                })?
            }
        }
        Ok(())
    }

    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }

    pub fn len(&self) -> usize {
        self.capacity() - self.spare()
    }

    pub fn capacity(&self) -> usize {
        self.queue.capacity()
    }

    pub fn spare(&self) -> usize {
        self.queue.spare_capacity()
    }

    pub(crate) fn try_lock(&mut self) -> bool {
        self.stealing
            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok()
    }

    pub(crate) fn release_lock(&mut self) {
        self.stealing.store(false, Ordering::Relaxed);
    }

    /// If the element is a closure, you still have to take the raw pointer and
    /// convert it back manually, otherwise the type will be wrong.
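    ///
    /// A hedged sketch of that closure round-trip (names are illustrative,
    /// not part of this module):
    ///
    /// ```ignore
    /// let queue = get_queue();
    /// let f: Box<dyn FnOnce()> = Box::new(|| println!("run me"));
    /// // double-box so the fat pointer fits into a thin *mut c_void
    /// queue.push_back_raw(Box::into_raw(Box::new(f)) as *mut c_void).unwrap();
    /// if let Some(ptr) = queue.pop_front_raw() {
    ///     let f = unsafe { *Box::from_raw(ptr as *mut Box<dyn FnOnce()>) };
    ///     f();
    /// }
    /// ```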
    pub fn pop_front_raw(&mut self) -> Option<*mut c_void> {
        // pop from the local queue first
        if let Some(val) = self.queue.pop() {
            return Some(val);
        }
        unsafe {
            if self.try_lock() {
                // try to steal from the global queue
                if WorkStealQueue::try_global_lock() {
                    if let Ok(popped_item) = GLOBAL_QUEUE.pop() {
                        self.steal_global(self.queue.capacity() / 2);
                        self.release_lock();
                        return Some(popped_item);
                    }
                    // nothing to steal: release the global lock, otherwise it
                    // would stay held forever
                    GLOBAL_LOCK.store(false, Ordering::Relaxed);
                }
                // try to steal from the other local queues
                let local_queues = LOCAL_QUEUES.get_mut().unwrap();
                // build a partially shuffled array of indexes and iterate over it
                let mut indexes = Vec::new();
                let len = local_queues.len();
                for i in 0..len {
                    indexes.push(i);
                }
                for i in 0..(len / 2) {
                    let random = Rng {
                        state: timer_utils::now(),
                    }
                    .gen_usize_to(len);
                    indexes.swap(i, random);
                }
                for i in indexes {
                    let another: &mut WorkStealQueue =
                        local_queues.get_mut(i).expect("get local queue failed!");
                    if self.steal_siblings(another, usize::MAX).is_ok() {
                        self.release_lock();
                        return self.queue.pop();
                    }
                }
                self.release_lock();
            }
            // fall back to the global queue
            match GLOBAL_QUEUE.pop() {
                Ok(item) => Some(item),
                Err(_) => None,
            }
        }
    }

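    /// Steal items from `another` local queue into this one. At most half of
    /// the sibling's items are taken, further capped by `count` and by this
    /// queue's spare capacity; stealing from oneself or from an empty sibling
    /// fails with the corresponding `StealError`.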
    pub(crate) fn steal_siblings(
        &mut self,
        another: &mut WorkStealQueue,
        count: usize,
    ) -> Result<(), StealError> {
        if std::ptr::eq(&another.queue, &self.queue) {
            return Err(StealError::CanNotStealSelf);
        }
        if another.is_empty() {
            return Err(StealError::EmptySibling);
        }
        let count = (another.len() / 2)
            .min(self.queue.spare_capacity())
            .min(count);
        if count == 0 {
            return Err(StealError::NoMoreSpare);
        }
        another
            .queue
            .stealer()
            .steal(&self.queue, |_n| count)
            .map_err(|_| StealError::StealSiblingFailed)
            .map(|_| ())
    }

    pub(crate) fn try_global_lock() -> bool {
        unsafe {
            GLOBAL_LOCK
                .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
        }
    }

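    /// Move up to `count` items (bounded by the local spare capacity) from the
    /// global queue into this local queue, then release the global steal lock.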
    pub(crate) fn steal_global(&mut self, count: usize) {
        unsafe {
            let count = count.min(self.queue.spare_capacity());
            for _ in 0..count {
                match GLOBAL_QUEUE.pop() {
                    Ok(item) => self.queue.push(item).expect("steal to local queue failed!"),
                    Err(_) => break,
                }
            }
            GLOBAL_LOCK.store(false, Ordering::Relaxed);
        }
    }
}

// #[cfg(test)]
// mod tests {
//     use super::*;
//     use std::os::raw::c_void;
//
//     #[test]
//     fn steal_global() {
//         for i in 0..16 {
//             unsafe {
//                 INSTANCE.push_raw(i as *mut c_void);
//             }
//         }
//         let local = get_queue();
//         for i in 0..16 {
//             assert_eq!(local.pop_front_raw().unwrap(), i as *mut c_void);
//         }
//         assert_eq!(local.pop_front_raw(), None);
//     }
//
//     #[test]
//     fn steal_siblings() {
//         unsafe {
//             INSTANCE.push_raw(2 as *mut c_void);
//             INSTANCE.push_raw(3 as *mut c_void);
//         }
//
//         let local0 = get_queue();
//         local0.push_back_raw(4 as *mut c_void);
//         local0.push_back_raw(5 as *mut c_void);
//         local0.push_back_raw(6 as *mut c_void);
//         local0.push_back_raw(7 as *mut c_void);
//
//         let local1 = get_queue();
//         local1.push_back_raw(0 as *mut c_void);
//         local1.push_back_raw(1 as *mut c_void);
//         for i in 0..7 {
//             assert_eq!(local1.pop_front_raw().unwrap(), i as *mut c_void);
//         }
//     }
// }