1use crate::common::beans::BeanFactory;
2use crate::common::constants::{CoroutineState, SyscallState};
3use crate::common::ordered_work_steal::{OrderedLocalQueue, OrderedWorkStealQueue};
4use crate::common::{get_timeout_time, now};
5use crate::coroutine::listener::Listener;
6use crate::coroutine::suspender::Suspender;
7use crate::coroutine::Coroutine;
8use crate::{co, impl_current_for, impl_display_by_debug, impl_for_named, warn};
9use dashmap::{DashMap, DashSet};
10#[cfg(unix)]
11use nix::sys::pthread::Pthread;
12use once_cell::sync::Lazy;
13use std::collections::{BinaryHeap, HashMap, VecDeque};
14use std::ffi::c_longlong;
15use std::io::Error;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::time::Duration;
18
19pub type SchedulableCoroutineState = CoroutineState<(), Option<usize>>;
21
22pub type SchedulableCoroutine<'s> = Coroutine<'s, (), (), Option<usize>>;
24
25pub type SchedulableSuspender<'s> = Suspender<'s, (), ()>;
27
28#[repr(C)]
29#[derive(Debug)]
30struct SuspendItem<'s> {
31 timestamp: u64,
32 coroutine: SchedulableCoroutine<'s>,
33}
34
35impl PartialEq<Self> for SuspendItem<'_> {
36 fn eq(&self, other: &Self) -> bool {
37 self.timestamp.eq(&other.timestamp)
38 }
39}
40
41impl Eq for SuspendItem<'_> {}
42
43impl PartialOrd<Self> for SuspendItem<'_> {
44 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
45 Some(self.cmp(other))
46 }
47}
48
49impl Ord for SuspendItem<'_> {
50 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
51 other.timestamp.cmp(&self.timestamp)
53 }
54}
55
/// Deadline of a coroutine parked in a suspended syscall, stored in the scheduler's
/// `syscall_suspend` heap. The coroutine itself lives in the `syscall` map under `co_name`.
///
/// Equality and ordering consider only `timestamp`, reversed so that `BinaryHeap`
/// behaves as a min-heap on the deadline.
#[repr(C)]
#[derive(Debug)]
struct SyscallSuspendItem<'s> {
    /// Time after which the syscall is considered timed out; compared against `now()`.
    timestamp: u64,
    /// Name of the coroutine in the `syscall` map.
    co_name: &'s str,
}

impl PartialEq<Self> for SyscallSuspendItem<'_> {
    fn eq(&self, other: &Self) -> bool {
        self.timestamp == other.timestamp
    }
}

impl Eq for SyscallSuspendItem<'_> {}

impl PartialOrd<Self> for SyscallSuspendItem<'_> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for SyscallSuspendItem<'_> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Wrapping both sides in `Reverse` flips the natural order of the deadline.
        std::cmp::Reverse(self.timestamp).cmp(&std::cmp::Reverse(other.timestamp))
    }
}
83
84#[cfg(unix)]
85static RUNNING_COROUTINES: Lazy<DashMap<&str, Pthread>> = Lazy::new(DashMap::new);
86#[cfg(windows)]
87static RUNNING_COROUTINES: Lazy<DashMap<&str, usize>> = Lazy::new(DashMap::new);
88
89static CANCEL_COROUTINES: Lazy<DashSet<&str>> = Lazy::new(DashSet::new);
90
/// Coroutine scheduler.
///
/// Between resumptions a coroutine lives in exactly one of: the `ready` queue, the
/// `suspend` heap (timed sleeps), or the `syscall` map (parked in a syscall, with an
/// optional deadline recorded in `syscall_suspend`).
#[repr(C)]
#[derive(Debug)]
pub struct Scheduler<'s> {
    /// Scheduler name; also the prefix of coroutine names generated by `submit_co`.
    name: String,
    /// Default stack size for coroutines submitted without an explicit one.
    stack_size: AtomicUsize,
    /// Listeners attached to every coroutine passed to `submit_raw_co`.
    listeners: VecDeque<&'s dyn Listener<(), Option<usize>>>,
    // Local handle onto the global ordered work-steal queue obtained from `BeanFactory`.
    #[doc = include_str!("../docs/en/ordered-work-steal.md")]
    ready: OrderedLocalQueue<'s, SchedulableCoroutine<'s>>,
    /// Coroutines sleeping until a deadline; earliest deadline pops first.
    suspend: BinaryHeap<SuspendItem<'s>>,
    /// Coroutines parked in a syscall, keyed by coroutine name.
    syscall: DashMap<&'s str, SchedulableCoroutine<'s>>,
    /// Deadlines of suspended syscalls. Entries can outlive their `syscall` entry,
    /// because `try_resume` removes from `syscall` without touching this heap.
    syscall_suspend: BinaryHeap<SyscallSuspendItem<'s>>,
}
104
105impl Default for Scheduler<'_> {
106 fn default() -> Self {
107 Self::new(
108 format!("open-coroutine-scheduler-{:?}", std::thread::current().id()),
109 crate::common::constants::DEFAULT_STACK_SIZE,
110 )
111 }
112}
113
114impl Drop for Scheduler<'_> {
115 fn drop(&mut self) {
116 if std::thread::panicking() {
117 return;
118 }
119 let name = self.name.clone();
120 _ = self
121 .try_timed_schedule(Duration::from_secs(30))
122 .unwrap_or_else(|e| panic!("Failed to stop scheduler {name} due to {e} !"));
123 assert!(
124 self.ready.is_empty(),
125 "There are still coroutines to be carried out in the ready queue:{:#?} !",
126 self.ready
127 );
128 assert!(
129 self.suspend.is_empty(),
130 "There are still coroutines to be carried out in the suspend queue:{:#?} !",
131 self.suspend
132 );
133 assert!(
134 self.syscall.is_empty(),
135 "There are still coroutines to be carried out in the syscall queue:{:#?} !",
136 self.syscall
137 );
138 }
139}
140
// NOTE(review): these macros are defined elsewhere in the crate. From their names and
// usage they presumably provide a name-based trait impl, "current scheduler" access
// (`init_current` / `clean_current`, called from `try_timeout_schedule`) backed by
// `SCHEDULER`, and a `Display` impl delegating to `Debug` — confirm against the macros.
impl_for_named!(Scheduler<'s>);

impl_current_for!(SCHEDULER, Scheduler<'s>);

impl_display_by_debug!(Scheduler<'s>);
146
147#[allow(clippy::type_complexity)]
148impl<'s> Scheduler<'s> {
149 #[must_use]
151 pub fn new(name: String, stack_size: usize) -> Self {
152 Scheduler {
153 name,
154 stack_size: AtomicUsize::new(stack_size),
155 listeners: VecDeque::new(),
156 ready: BeanFactory::get_or_default::<OrderedWorkStealQueue<SchedulableCoroutine>>(
157 crate::common::constants::COROUTINE_GLOBAL_QUEUE_BEAN,
158 )
159 .local_queue(),
160 suspend: BinaryHeap::default(),
161 syscall: DashMap::default(),
162 syscall_suspend: BinaryHeap::default(),
163 }
164 }
165
166 pub fn name(&self) -> &str {
168 &self.name
169 }
170
171 pub fn stack_size(&self) -> usize {
174 self.stack_size.load(Ordering::Acquire)
175 }
176
177 pub fn submit_co(
185 &self,
186 f: impl FnOnce(&Suspender<(), ()>, ()) -> Option<usize> + 'static,
187 stack_size: Option<usize>,
188 priority: Option<c_longlong>,
189 ) -> std::io::Result<String> {
190 self.submit_raw_co(co!(
191 Some(format!("{}@{}", self.name(), uuid::Uuid::new_v4())),
192 f,
193 Some(stack_size.unwrap_or(self.stack_size())),
194 priority
195 )?)
196 }
197
198 pub fn add_listener(&mut self, listener: impl Listener<(), Option<usize>> + 's) {
200 self.listeners.push_back(Box::leak(Box::new(listener)));
201 }
202
203 pub fn submit_raw_co(&self, mut co: SchedulableCoroutine<'s>) -> std::io::Result<String> {
208 for listener in self.listeners.clone() {
209 co.add_raw_listener(listener);
210 }
211 let co_name = co.name().to_string();
212 self.ready.push(co);
213 Ok(co_name)
214 }
215
216 pub fn try_resume(&self, co_name: &'s str) {
224 if let Some((_, co)) = self.syscall.remove(&co_name) {
225 match co.state() {
226 CoroutineState::Syscall(val, syscall, SyscallState::Suspend(_)) => {
227 co.syscall(val, syscall, SyscallState::Callback)
228 .expect("change syscall state failed");
229 }
230 _ => unreachable!("try_resume unexpect CoroutineState"),
231 }
232 self.ready.push(co);
233 }
234 }
235
236 pub fn try_schedule(&mut self) -> std::io::Result<HashMap<&str, Result<Option<usize>, &str>>> {
244 self.try_timeout_schedule(u64::MAX)
245 .map(|(_, results)| results)
246 }
247
248 pub fn try_timed_schedule(
256 &mut self,
257 dur: Duration,
258 ) -> std::io::Result<(u64, HashMap<&str, Result<Option<usize>, &str>>)> {
259 self.try_timeout_schedule(get_timeout_time(dur))
260 }
261
262 pub fn try_timeout_schedule(
272 &mut self,
273 timeout_time: u64,
274 ) -> std::io::Result<(u64, HashMap<&str, Result<Option<usize>, &str>>)> {
275 Self::init_current(self);
276 let r = self.do_schedule(timeout_time);
277 Self::clean_current();
278 r
279 }
280
281 fn do_schedule(
282 &mut self,
283 timeout_time: u64,
284 ) -> std::io::Result<(u64, HashMap<&str, Result<Option<usize>, &str>>)> {
285 let mut results = HashMap::new();
286 loop {
287 let left_time = timeout_time.saturating_sub(now());
288 if 0 == left_time {
289 return Ok((0, results));
290 }
291 self.check_ready()?;
292 if let Some(mut coroutine) = self.ready.pop() {
294 let co_name = coroutine.name().to_string().leak();
295 if CANCEL_COROUTINES.contains(co_name) {
296 _ = CANCEL_COROUTINES.remove(co_name);
297 warn!("Cancel coroutine:{} successfully !", co_name);
298 continue;
299 }
300 cfg_if::cfg_if! {
301 if #[cfg(windows)] {
302 let current_thread = unsafe {
303 windows_sys::Win32::System::Threading::GetCurrentThread()
304 } as usize;
305 } else {
306 let current_thread = nix::sys::pthread::pthread_self();
307 }
308 }
309 _ = RUNNING_COROUTINES.insert(co_name, current_thread);
310 match coroutine.resume().inspect(|_| {
311 _ = RUNNING_COROUTINES.remove(co_name);
312 })? {
313 CoroutineState::Syscall((), _, state) => {
314 _ = self.syscall.insert(co_name, coroutine);
317 if let SyscallState::Suspend(timestamp) = state {
318 self.syscall_suspend
319 .push(SyscallSuspendItem { timestamp, co_name });
320 }
321 }
322 CoroutineState::Suspend((), timestamp) => {
323 if timestamp > now() {
324 self.suspend.push(SuspendItem {
326 timestamp,
327 coroutine,
328 });
329 } else {
330 self.ready.push(coroutine);
332 }
333 }
334 CoroutineState::Cancelled => {}
335 CoroutineState::Complete(result) => {
336 assert!(
337 results.insert(co_name, Ok(result)).is_none(),
338 "not consume result"
339 );
340 }
341 CoroutineState::Error(message) => {
342 assert!(
343 results.insert(co_name, Err(message)).is_none(),
344 "not consume result"
345 );
346 }
347 _ => {
348 return Err(Error::other(
349 "try_timeout_schedule should never execute to here",
350 ));
351 }
352 }
353 continue;
354 }
355 return Ok((left_time, results));
356 }
357 }
358
359 fn check_ready(&mut self) -> std::io::Result<()> {
360 while let Some(item) = self.suspend.peek() {
362 if now() < item.timestamp {
363 break;
364 }
365 if let Some(item) = self.suspend.pop() {
366 item.coroutine.ready()?;
367 self.ready.push(item.coroutine);
368 }
369 }
370 while let Some(item) = self.syscall_suspend.peek() {
372 if now() < item.timestamp {
373 break;
374 }
375 if let Some(item) = self.syscall_suspend.pop() {
376 if let Some((_, co)) = self.syscall.remove(item.co_name) {
377 match co.state() {
378 CoroutineState::Syscall(val, syscall, SyscallState::Suspend(_)) => {
379 co.syscall(val, syscall, SyscallState::Timeout)?;
380 self.ready.push(co);
381 }
382 _ => unreachable!("check_ready should never execute to here"),
383 }
384 }
385 }
386 }
387 Ok(())
388 }
389
390 pub fn try_cancel_coroutine(co_name: &str) {
392 _ = CANCEL_COROUTINES.insert(Box::leak(Box::from(co_name)));
393 }
394
395 #[cfg(unix)]
397 pub fn get_scheduling_thread(co_name: &str) -> Option<Pthread> {
398 let co_name: &str = Box::leak(Box::from(co_name));
399 RUNNING_COROUTINES.get(co_name).map(|r| *r)
400 }
401
402 #[cfg(windows)]
404 pub fn get_scheduling_thread(co_name: &str) -> Option<windows_sys::Win32::Foundation::HANDLE> {
405 let co_name: &str = Box::leak(Box::from(co_name));
406 RUNNING_COROUTINES
407 .get(co_name)
408 .map(|r| *r as windows_sys::Win32::Foundation::HANDLE)
409 }
410}
411
#[cfg(test)]
mod tests {
    use crate::scheduler::SyscallSuspendItem;
    use std::collections::BinaryHeap;

    /// Items pushed with descending timestamps must pop in ascending order, proving the
    /// reversed `Ord` turns `BinaryHeap` into a min-heap on the deadline.
    #[test]
    fn test_small_heap() {
        let mut heap: BinaryHeap<SyscallSuspendItem> = (0..10u64)
            .rev()
            .map(|timestamp| SyscallSuspendItem {
                timestamp,
                co_name: "test",
            })
            .collect();
        let popped: Vec<u64> =
            std::iter::from_fn(|| heap.pop().map(|item| item.timestamp)).collect();
        assert_eq!(popped, (0..10u64).collect::<Vec<_>>());
    }
}
430}