base_coroutine/
work_steal.rs1use crate::random::Rng;
2use concurrent_queue::{ConcurrentQueue, PushError};
3use once_cell::sync::{Lazy, OnceCell};
4use st3::fifo::Worker;
5use std::error::Error;
6use std::fmt::{Display, Formatter};
7use std::io::ErrorKind;
8use std::os::raw::c_void;
9use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
10
/// Process-wide dispenser of local queues, built on first use with
/// `Queue::default` (which also initializes `LOCAL_QUEUES`).
static mut INSTANCE: Lazy<Queue> = Lazy::new(Queue::default);

/// Returns one of the local work-stealing queues, advancing `INSTANCE`'s
/// round-robin counter on every call.
///
/// NOTE(review): this hands out `&'static mut` references into shared global
/// state. Once the counter has cycled past the number of local queues, two
/// callers can hold `&mut` to the same queue. Presumably intended to be called
/// once per worker thread; that is not enforced here.
pub fn get_queue() -> &'static mut WorkStealQueue {
    unsafe { INSTANCE.local_queue() }
}
16
/// Non-blocking try-lock flag for batch refills from `GLOBAL_QUEUE`.
/// Set by `WorkStealQueue::try_global_lock`; cleared by `WorkStealQueue::steal_global`.
///
/// NOTE(review): `AtomicBool::new` is `const`, so a plain
/// `static GLOBAL_LOCK: AtomicBool` would work without `static mut`/`Lazy`.
static mut GLOBAL_LOCK: Lazy<AtomicBool> = Lazy::new(|| AtomicBool::new(false));

/// Shared, unbounded queue of type-erased item pointers. Local queues spill
/// into it when they are full and take items from it when they run empty.
///
/// NOTE(review): presumably `static mut` because `*mut c_void` is not `Send`,
/// which makes the queue not `Sync`; confirm before changing.
pub(crate) static mut GLOBAL_QUEUE: Lazy<ConcurrentQueue<*mut c_void>> =
    Lazy::new(ConcurrentQueue::unbounded);

/// The per-worker local queues, created once by `Queue::new`.
pub(crate) static mut LOCAL_QUEUES: OnceCell<Box<[WorkStealQueue]>> = OnceCell::new();
23
/// Round-robin dispenser of the local queues stored in `LOCAL_QUEUES`.
#[derive(Debug)]
#[repr(C)]
struct Queue {
    /// Counter bumped on every `local_queue` call to pick the next local queue.
    index: AtomicUsize,
}
29
30impl Queue {
31 fn new(local_queues: usize, local_capacity: usize) -> Self {
32 unsafe {
33 LOCAL_QUEUES.get_or_init(|| {
34 (0..local_queues)
35 .map(|_| WorkStealQueue::new(local_capacity))
36 .collect()
37 });
38 }
39 Queue {
40 index: AtomicUsize::new(0),
41 }
42 }
43
44 fn push<T>(&self, item: T) {
47 let ptr = Box::leak(Box::new(item));
48 self.push_raw(ptr as *mut _ as *mut c_void)
49 }
50
51 fn push_raw(&self, ptr: *mut c_void) {
52 unsafe { GLOBAL_QUEUE.push(ptr).unwrap() }
53 }
54
55 fn local_queue(&mut self) -> &mut WorkStealQueue {
56 let index = self.index.fetch_add(1, Ordering::Relaxed);
57 if index == usize::MAX {
58 self.index.store(0, Ordering::Relaxed);
59 }
60 unsafe {
61 LOCAL_QUEUES
62 .get_mut()
63 .unwrap()
64 .get_mut(index % num_cpus::get())
65 .unwrap()
66 }
67 }
68}
69
70impl Default for Queue {
71 fn default() -> Self {
72 Self::new(num_cpus::get(), 256)
73 }
74}
75
/// Reasons why stealing work from a sibling local queue did not happen.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StealError {
    /// The "sibling" is the very queue that tried to steal.
    CanNotStealSelf,
    /// The sibling queue holds no items.
    EmptySibling,
    /// The number of items to steal came out as zero (no spare room here, or
    /// too few items in the sibling).
    NoMoreSpare,
    /// The underlying steal operation on the sibling failed.
    StealSiblingFailed,
}
84
85impl Display for StealError {
86 fn fmt(&self, fmt: &mut Formatter) -> std::fmt::Result {
87 match *self {
88 StealError::CanNotStealSelf => write!(fmt, "can not steal self"),
89 StealError::EmptySibling => write!(fmt, "the sibling is empty"),
90 StealError::NoMoreSpare => write!(fmt, "self has no more spare"),
91 StealError::StealSiblingFailed => write!(fmt, "steal from another local queue failed"),
92 }
93 }
94}
95
96impl Error for StealError {
97 fn source(&self) -> Option<&(dyn Error + 'static)> {
98 None
99 }
100}
101
102#[repr(C)]
103#[derive(Debug)]
104pub struct WorkStealQueue {
105 stealing: AtomicBool,
106 queue: Worker<*mut c_void>,
107}
108
109impl WorkStealQueue {
110 fn new(max_capacity: usize) -> Self {
111 WorkStealQueue {
112 stealing: AtomicBool::new(false),
113 queue: Worker::new(max_capacity),
114 }
115 }
116
117 pub fn push_back<T>(&mut self, element: T) -> std::io::Result<()> {
118 let ptr = Box::leak(Box::new(element));
119 self.push_back_raw(ptr as *mut _ as *mut c_void)
120 }
121
122 pub fn push_back_raw(&mut self, ptr: *mut c_void) -> std::io::Result<()> {
123 if let Err(item) = self.queue.push(ptr) {
124 unsafe {
125 let count = self.len() / 2;
127 let half = Worker::new(count);
129 let stealer = self.queue.stealer();
130 let _ = stealer.steal(&half, |_n| count);
131 while !half.is_empty() {
132 let _ = GLOBAL_QUEUE.push(half.pop().unwrap());
133 }
134 GLOBAL_QUEUE.push(item).map_err(|e| match e {
135 PushError::Full(_) => {
136 std::io::Error::new(ErrorKind::Other, "global queue is full")
137 }
138 PushError::Closed(_) => {
139 std::io::Error::new(ErrorKind::Other, "global queue closed")
140 }
141 })?
142 }
143 }
144 Ok(())
145 }
146
147 pub fn is_empty(&self) -> bool {
148 self.queue.is_empty()
149 }
150
151 pub fn len(&self) -> usize {
152 self.capacity() - self.spare()
153 }
154
155 pub fn capacity(&self) -> usize {
156 self.queue.capacity()
157 }
158
159 pub fn spare(&self) -> usize {
160 self.queue.spare_capacity()
161 }
162
163 pub(crate) fn try_lock(&mut self) -> bool {
164 self.stealing
165 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
166 .is_ok()
167 }
168
169 pub(crate) fn release_lock(&mut self) {
170 self.stealing.store(false, Ordering::Relaxed);
171 }
172
173 pub fn pop_front_raw(&mut self) -> Option<*mut c_void> {
175 if let Some(val) = self.queue.pop() {
177 return Some(val);
178 }
179 unsafe {
180 if self.try_lock() {
181 if WorkStealQueue::try_global_lock() {
183 if let Ok(popped_item) = GLOBAL_QUEUE.pop() {
184 self.steal_global(self.queue.capacity() / 2);
185 self.release_lock();
186 return Some(popped_item);
187 }
188 }
189 let local_queues = LOCAL_QUEUES.get_mut().unwrap();
191 let mut indexes = Vec::new();
193 let len = local_queues.len();
194 for i in 0..len {
195 indexes.push(i);
196 }
197 for i in 0..(len / 2) {
198 let random = Rng {
199 state: timer_utils::now(),
200 }
201 .gen_usize_to(len);
202 indexes.swap(i, random);
203 }
204 for i in indexes {
205 let another: &mut WorkStealQueue =
206 local_queues.get_mut(i).expect("get local queue failed!");
207 if self.steal_siblings(another, usize::MAX).is_ok() {
208 self.release_lock();
209 return self.queue.pop();
210 }
211 }
212 self.release_lock();
213 }
214 match GLOBAL_QUEUE.pop() {
215 Ok(item) => Some(item),
216 Err(_) => None,
217 }
218 }
219 }
220
221 pub(crate) fn steal_siblings(
222 &mut self,
223 another: &mut WorkStealQueue,
224 count: usize,
225 ) -> Result<(), StealError> {
226 if std::ptr::eq(&another.queue, &self.queue) {
227 return Err(StealError::CanNotStealSelf);
228 }
229 if another.is_empty() {
230 return Err(StealError::EmptySibling);
231 }
232 let count = (another.len() / 2)
233 .min(self.queue.spare_capacity())
234 .min(count);
235 if count == 0 {
236 return Err(StealError::NoMoreSpare);
237 }
238 another
239 .queue
240 .stealer()
241 .steal(&self.queue, |_n| count)
242 .map_err(|_| StealError::StealSiblingFailed)
243 .map(|_| ())
244 }
245
246 pub(crate) fn try_global_lock() -> bool {
247 unsafe {
248 GLOBAL_LOCK
249 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
250 .is_ok()
251 }
252 }
253
254 pub(crate) fn steal_global(&mut self, count: usize) {
255 unsafe {
256 let count = count.min(self.queue.spare_capacity());
257 for _ in 0..count {
258 match GLOBAL_QUEUE.pop() {
259 Ok(item) => self.queue.push(item).expect("steal to local queue failed!"),
260 Err(_) => break,
261 }
262 }
263 GLOBAL_LOCK.store(false, Ordering::Relaxed);
264 }
265 }
266}
267
268