1use crate::common::beans::BeanFactory;
2use crate::common::constants::{CoroutineState, SyscallState};
3use crate::common::ordered_work_steal::{OrderedLocalQueue, OrderedWorkStealQueue};
4use crate::common::{get_timeout_time, now};
5use crate::coroutine::listener::Listener;
6use crate::coroutine::suspender::Suspender;
7use crate::coroutine::Coroutine;
8use crate::{co, impl_current_for, impl_display_by_debug, impl_for_named};
9use dashmap::DashMap;
10use std::collections::{BinaryHeap, VecDeque};
11use std::ffi::c_longlong;
12use std::io::{Error, ErrorKind};
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::time::Duration;
15
/// State of the coroutines run by [`Scheduler`]: they yield `()` and complete with an
/// `Option<usize>`.
pub type SchedulableCoroutineState = CoroutineState<(), Option<usize>>;

/// Coroutine type run by [`Scheduler`]; submitted closures take `()` and return
/// `Option<usize>` (see `Scheduler::submit_co`).
pub type SchedulableCoroutine<'s> = Coroutine<'s, (), (), Option<usize>>;

/// Suspender type matching [`SchedulableCoroutine`].
pub type SchedulableSuspender<'s> = Suspender<'s, (), ()>;
24
25#[repr(C)]
26#[derive(Debug)]
27struct SuspendItem<'s> {
28 timestamp: u64,
29 coroutine: SchedulableCoroutine<'s>,
30}
31
32impl PartialEq<Self> for SuspendItem<'_> {
33 fn eq(&self, other: &Self) -> bool {
34 self.timestamp.eq(&other.timestamp)
35 }
36}
37
38impl Eq for SuspendItem<'_> {}
39
40impl PartialOrd<Self> for SuspendItem<'_> {
41 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
42 Some(self.cmp(other))
43 }
44}
45
46impl Ord for SuspendItem<'_> {
47 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
48 other.timestamp.cmp(&self.timestamp)
50 }
51}
52
/// Deadline of a coroutine that is suspended inside a syscall.
///
/// `co_name` is the coroutine's key in the scheduler's syscall table. Only `timestamp`
/// takes part in equality and ordering. The ordering is reversed, so a `BinaryHeap` (a
/// max-heap) pops the earliest deadline first.
#[repr(C)]
#[derive(Debug)]
struct SyscallSuspendItem<'s> {
    timestamp: u64,
    co_name: &'s str,
}

impl PartialEq<Self> for SyscallSuspendItem<'_> {
    fn eq(&self, other: &Self) -> bool {
        self.timestamp == other.timestamp
    }
}

impl Eq for SyscallSuspendItem<'_> {}

impl PartialOrd<Self> for SyscallSuspendItem<'_> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for SyscallSuspendItem<'_> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reverse the natural timestamp order to get min-heap behaviour.
        self.timestamp.cmp(&other.timestamp).reverse()
    }
}
80
/// A coroutine scheduler: owns the ready/suspend/syscall queues and drives coroutines
/// through them (see `try_timeout_schedule`).
#[repr(C)]
#[derive(Debug)]
pub struct Scheduler<'s> {
    // Scheduler name; also used as the prefix of coroutine names created by `submit_co`.
    name: String,
    // Default stack size for coroutines submitted without an explicit one.
    stack_size: AtomicUsize,
    // Listeners attached to every coroutine passed to `submit_raw_co` (leaked by `add_listener`).
    listeners: VecDeque<&'s dyn Listener<(), Option<usize>>>,
    // Local handle of the global work-steal queue bean; coroutines ready to be resumed.
    ready: OrderedLocalQueue<'s, SchedulableCoroutine<'s>>,
    // Coroutines suspended until a future timestamp, earliest first.
    suspend: BinaryHeap<SuspendItem<'s>>,
    // Coroutines currently inside a syscall, keyed by coroutine name.
    syscall: DashMap<&'s str, SchedulableCoroutine<'s>>,
    // Deadlines of syscall-suspended coroutines, earliest first; entries may be stale
    // (`try_resume` does not remove them).
    syscall_suspend: BinaryHeap<SyscallSuspendItem<'s>>,
    // Results of finished coroutines (`Ok` on completion, `Err` on error), keyed by name.
    results: DashMap<&'s str, Result<Option<usize>, &'s str>>,
}
94
95impl Default for Scheduler<'_> {
96 fn default() -> Self {
97 Self::new(
98 format!("open-coroutine-scheduler-{:?}", std::thread::current().id()),
99 crate::common::constants::DEFAULT_STACK_SIZE,
100 )
101 }
102}
103
104impl Drop for Scheduler<'_> {
105 fn drop(&mut self) {
106 if std::thread::panicking() {
107 return;
108 }
109 _ = self
110 .try_timed_schedule(Duration::from_secs(30))
111 .unwrap_or_else(|_| panic!("Failed to stop scheduler {} !", self.name()));
112 assert!(
113 self.ready.is_empty(),
114 "There are still coroutines to be carried out in the ready queue:{:#?} !",
115 self.ready
116 );
117 assert!(
118 self.suspend.is_empty(),
119 "There are still coroutines to be carried out in the suspend queue:{:#?} !",
120 self.suspend
121 );
122 assert!(
123 self.syscall.is_empty(),
124 "There are still coroutines to be carried out in the syscall queue:{:#?} !",
125 self.syscall
126 );
127 }
128}
129
// Presumably implements the crate's naming helpers for `Scheduler` (macro defined
// elsewhere — confirm).
impl_for_named!(Scheduler<'s>);

// Presumably generates the "current scheduler" accessors backed by `SCHEDULER`;
// `Self::init_current`/`Self::clean_current` (called in `try_timeout_schedule`) are not
// defined in this file and likely come from here — confirm.
impl_current_for!(SCHEDULER, Scheduler<'s>);

// Presumably implements `Display` by delegating to the derived `Debug` output (inferred
// from the macro name — confirm).
impl_display_by_debug!(Scheduler<'s>);
135
136impl<'s> Scheduler<'s> {
137 #[must_use]
139 pub fn new(name: String, stack_size: usize) -> Self {
140 Scheduler {
141 name,
142 stack_size: AtomicUsize::new(stack_size),
143 listeners: VecDeque::new(),
144 ready: BeanFactory::get_or_default::<OrderedWorkStealQueue<SchedulableCoroutine>>(
145 crate::common::constants::COROUTINE_GLOBAL_QUEUE_BEAN,
146 )
147 .local_queue(),
148 suspend: BinaryHeap::default(),
149 syscall: DashMap::default(),
150 syscall_suspend: BinaryHeap::default(),
151 results: DashMap::default(),
152 }
153 }
154
155 pub fn name(&self) -> &str {
157 &self.name
158 }
159
160 pub fn stack_size(&self) -> usize {
163 self.stack_size.load(Ordering::Acquire)
164 }
165
166 pub fn submit_co(
174 &self,
175 f: impl FnOnce(&Suspender<(), ()>, ()) -> Option<usize> + 'static,
176 stack_size: Option<usize>,
177 priority: Option<c_longlong>,
178 ) -> std::io::Result<()> {
179 self.submit_raw_co(co!(
180 Some(format!("{}@{}", self.name(), uuid::Uuid::new_v4())),
181 f,
182 Some(stack_size.unwrap_or(self.stack_size())),
183 priority
184 )?)
185 }
186
187 pub fn add_listener(&mut self, listener: impl Listener<(), Option<usize>> + 's) {
189 self.listeners.push_back(Box::leak(Box::new(listener)));
190 }
191
192 pub fn submit_raw_co(&self, mut co: SchedulableCoroutine<'s>) -> std::io::Result<()> {
197 for listener in self.listeners.clone() {
198 co.add_raw_listener(listener);
199 }
200 self.ready.push(co);
201 Ok(())
202 }
203
204 pub fn try_resume(&self, co_name: &'s str) {
212 if let Some((_, co)) = self.syscall.remove(&co_name) {
213 match co.state() {
214 CoroutineState::Syscall(val, syscall, SyscallState::Suspend(_)) => {
215 co.syscall(val, syscall, SyscallState::Callback)
216 .expect("change syscall state failed");
217 }
218 _ => unreachable!("try_resume unexpect CoroutineState"),
219 }
220 self.ready.push(co);
221 }
222 }
223
224 pub fn try_schedule(&mut self) -> std::io::Result<()> {
232 self.try_timeout_schedule(u64::MAX).map(|_| ())
233 }
234
235 pub fn try_timed_schedule(&mut self, dur: Duration) -> std::io::Result<u64> {
243 self.try_timeout_schedule(get_timeout_time(dur))
244 }
245
246 pub fn try_timeout_schedule(&mut self, timeout_time: u64) -> std::io::Result<u64> {
256 Self::init_current(self);
257 let left_time = self.do_schedule(timeout_time);
258 Self::clean_current();
259 left_time
260 }
261
262 fn do_schedule(&mut self, timeout_time: u64) -> std::io::Result<u64> {
263 loop {
264 let left_time = timeout_time.saturating_sub(now());
265 if 0 == left_time {
266 return Ok(0);
267 }
268 self.check_ready()?;
269 if let Some(mut coroutine) = self.ready.pop() {
271 match coroutine.resume()? {
272 CoroutineState::Syscall((), _, state) => {
273 let co_name = Box::leak(Box::from(coroutine.name()));
275 _ = self.syscall.insert(co_name, coroutine);
277 if let SyscallState::Suspend(timestamp) = state {
278 self.syscall_suspend
279 .push(SyscallSuspendItem { timestamp, co_name });
280 }
281 }
282 CoroutineState::Suspend((), timestamp) => {
283 if timestamp > now() {
284 self.suspend.push(SuspendItem {
286 timestamp,
287 coroutine,
288 });
289 } else {
290 self.ready.push(coroutine);
292 }
293 }
294 CoroutineState::Complete(result) => {
295 let co_name = Box::leak(Box::from(coroutine.name()));
296 assert!(
297 self.results.insert(co_name, Ok(result)).is_none(),
298 "not consume result"
299 );
300 }
301 CoroutineState::Error(message) => {
302 let co_name = Box::leak(Box::from(coroutine.name()));
303 assert!(
304 self.results.insert(co_name, Err(message)).is_none(),
305 "not consume result"
306 );
307 }
308 _ => {
309 return Err(Error::new(
310 ErrorKind::Other,
311 "try_timeout_schedule should never execute to here",
312 ));
313 }
314 }
315 continue;
316 }
317 return Ok(left_time);
318 }
319 }
320
321 fn check_ready(&mut self) -> std::io::Result<()> {
322 while let Some(item) = self.suspend.peek() {
324 if now() < item.timestamp {
325 break;
326 }
327 if let Some(item) = self.suspend.pop() {
328 item.coroutine.ready()?;
329 self.ready.push(item.coroutine);
330 }
331 }
332 while let Some(item) = self.syscall_suspend.peek() {
334 if now() < item.timestamp {
335 break;
336 }
337 if let Some(item) = self.syscall_suspend.pop() {
338 if let Some((_, co)) = self.syscall.remove(item.co_name) {
339 match co.state() {
340 CoroutineState::Syscall(val, syscall, SyscallState::Suspend(_)) => {
341 co.syscall(val, syscall, SyscallState::Timeout)?;
342 self.ready.push(co);
343 }
344 _ => unreachable!("check_ready should never execute to here"),
345 }
346 }
347 }
348 }
349 Ok(())
350 }
351}
352
#[cfg(test)]
mod tests {
    use crate::scheduler::SyscallSuspendItem;
    use std::collections::BinaryHeap;

    /// Deadlines inserted in descending order must come out in ascending order: the
    /// reversed `Ord` turns `BinaryHeap` into a min-heap.
    #[test]
    fn test_small_heap() {
        let mut heap: BinaryHeap<SyscallSuspendItem<'_>> = (0..10)
            .rev()
            .map(|timestamp| SyscallSuspendItem {
                timestamp,
                co_name: "test",
            })
            .collect();
        for expected in 0u64..10 {
            let popped = heap.pop().unwrap();
            assert_eq!(expected, popped.timestamp);
        }
    }
}
371}