open_coroutine_core/
scheduler.rs

use crate::common::beans::BeanFactory;
use crate::common::constants::{CoroutineState, SyscallState};
use crate::common::ordered_work_steal::{OrderedLocalQueue, OrderedWorkStealQueue};
use crate::common::{get_timeout_time, now};
use crate::coroutine::listener::Listener;
use crate::coroutine::suspender::Suspender;
use crate::coroutine::Coroutine;
use crate::{co, impl_current_for, impl_display_by_debug, impl_for_named, warn};
use dashmap::{DashMap, DashSet};
#[cfg(unix)]
use nix::sys::pthread::Pthread;
use once_cell::sync::Lazy;
use std::collections::{BinaryHeap, HashMap, VecDeque};
use std::ffi::c_longlong;
use std::io::Error;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

/// A type for Scheduler.
pub type SchedulableCoroutineState = CoroutineState<(), Option<usize>>;

/// A type for Scheduler.
pub type SchedulableCoroutine<'s> = Coroutine<'s, (), (), Option<usize>>;

/// A type for Scheduler.
pub type SchedulableSuspender<'s> = Suspender<'s, (), ()>;

#[repr(C)]
#[derive(Debug)]
struct SuspendItem<'s> {
    timestamp: u64,
    coroutine: SchedulableCoroutine<'s>,
}

impl PartialEq<Self> for SuspendItem<'_> {
    fn eq(&self, other: &Self) -> bool {
        self.timestamp.eq(&other.timestamp)
    }
}

impl Eq for SuspendItem<'_> {}

impl PartialOrd<Self> for SuspendItem<'_> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SuspendItem<'_> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // BinaryHeap is a max-heap by default, but we need a min-heap,
        // so compare the timestamps in reverse order.
        other.timestamp.cmp(&self.timestamp)
    }
}

#[repr(C)]
#[derive(Debug)]
struct SyscallSuspendItem<'s> {
    timestamp: u64,
    co_name: &'s str,
}

impl PartialEq<Self> for SyscallSuspendItem<'_> {
    fn eq(&self, other: &Self) -> bool {
        self.timestamp.eq(&other.timestamp)
    }
}

impl Eq for SyscallSuspendItem<'_> {}

impl PartialOrd<Self> for SyscallSuspendItem<'_> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SyscallSuspendItem<'_> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // BinaryHeap is a max-heap by default, but we need a min-heap,
        // so compare the timestamps in reverse order.
        other.timestamp.cmp(&self.timestamp)
    }
}

// Maps the name of a coroutine that is currently being resumed to the thread resuming it.
#[cfg(unix)]
static RUNNING_COROUTINES: Lazy<DashMap<&str, Pthread>> = Lazy::new(DashMap::new);
#[cfg(windows)]
static RUNNING_COROUTINES: Lazy<DashMap<&str, usize>> = Lazy::new(DashMap::new);

// Names of coroutines that have been asked to cancel before their next resume.
static CANCEL_COROUTINES: Lazy<DashSet<&str>> = Lazy::new(DashSet::new);

/// The coroutine scheduler.
#[repr(C)]
#[derive(Debug)]
pub struct Scheduler<'s> {
    name: String,
    stack_size: AtomicUsize,
    /// Listeners that are attached to every coroutine submitted to this scheduler.
    listeners: VecDeque<&'s dyn Listener<(), Option<usize>>>,
    #[doc = include_str!("../docs/en/ordered-work-steal.md")]
    ready: OrderedLocalQueue<'s, SchedulableCoroutine<'s>>,
    /// Coroutines suspended until their timestamp elapses, ordered as a min-heap.
    suspend: BinaryHeap<SuspendItem<'s>>,
    /// Coroutines currently blocked in a syscall, keyed by coroutine name.
    syscall: DashMap<&'s str, SchedulableCoroutine<'s>>,
    /// Wake-up timestamps for suspended syscall coroutines, ordered as a min-heap.
    syscall_suspend: BinaryHeap<SyscallSuspendItem<'s>>,
}

impl Default for Scheduler<'_> {
    fn default() -> Self {
        Self::new(
            format!("open-coroutine-scheduler-{:?}", std::thread::current().id()),
            crate::common::constants::DEFAULT_STACK_SIZE,
        )
    }
}

impl Drop for Scheduler<'_> {
    fn drop(&mut self) {
        if std::thread::panicking() {
            return;
        }
        let name = self.name.clone();
        // Try to drain the remaining coroutines for up to 30 seconds before checking the queues.
        _ = self
            .try_timed_schedule(Duration::from_secs(30))
            .unwrap_or_else(|e| panic!("Failed to stop scheduler {name} due to {e}!"));
        assert!(
            self.ready.is_empty(),
            "There are still coroutines to run in the ready queue: {:#?}!",
            self.ready
        );
        assert!(
            self.suspend.is_empty(),
            "There are still coroutines to run in the suspend queue: {:#?}!",
            self.suspend
        );
        assert!(
            self.syscall.is_empty(),
            "There are still coroutines to run in the syscall queue: {:#?}!",
            self.syscall
        );
    }
}

impl_for_named!(Scheduler<'s>);

impl_current_for!(SCHEDULER, Scheduler<'s>);

impl_display_by_debug!(Scheduler<'s>);

#[allow(clippy::type_complexity)]
impl<'s> Scheduler<'s> {
    /// Creates a new scheduler.
    #[must_use]
    pub fn new(name: String, stack_size: usize) -> Self {
        Scheduler {
            name,
            stack_size: AtomicUsize::new(stack_size),
            listeners: VecDeque::new(),
            ready: BeanFactory::get_or_default::<OrderedWorkStealQueue<SchedulableCoroutine>>(
                crate::common::constants::COROUTINE_GLOBAL_QUEUE_BEAN,
            )
            .local_queue(),
            suspend: BinaryHeap::default(),
            syscall: DashMap::default(),
            syscall_suspend: BinaryHeap::default(),
        }
    }

    /// Get the name of this scheduler.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Get the default stack size for the coroutines in this scheduler.
    /// If it has not been set, it will be [`crate::common::constants::DEFAULT_STACK_SIZE`].
    pub fn stack_size(&self) -> usize {
        self.stack_size.load(Ordering::Acquire)
    }

    /// Submit a closure to create a new coroutine; the coroutine will be pushed into the ready queue.
    ///
    /// Multiple threads may submit coroutines to the scheduler concurrently,
    /// but only one thread is allowed to execute the scheduling.
    ///
    /// # Errors
    /// if creating the coroutine fails.
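    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming an otherwise idle scheduler whose results are
    /// keyed by the returned coroutine name:
    /// ```rust,ignore
    /// let mut scheduler = Scheduler::default();
    /// // The closure receives a suspender and the `()` input, and returns `Option<usize>`.
    /// let co_name = scheduler
    ///     .submit_co(|_, ()| Some(1), None, None)
    ///     .expect("submit failed");
    /// let results = scheduler.try_schedule().expect("schedule failed");
    /// assert_eq!(results.get(co_name.as_str()), Some(&Ok(Some(1))));
    /// ```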
    pub fn submit_co(
        &self,
        f: impl FnOnce(&Suspender<(), ()>, ()) -> Option<usize> + 'static,
        stack_size: Option<usize>,
        priority: Option<c_longlong>,
    ) -> std::io::Result<String> {
        self.submit_raw_co(co!(
            Some(format!("{}@{}", self.name(), uuid::Uuid::new_v4())),
            f,
            Some(stack_size.unwrap_or(self.stack_size())),
            priority
        )?)
    }

    /// Add a listener to this scheduler.
    ///
    /// The listener is leaked and attached to every coroutine submitted afterwards.
    pub fn add_listener(&mut self, listener: impl Listener<(), Option<usize>> + 's) {
        self.listeners.push_back(Box::leak(Box::new(listener)));
    }

    /// Submit a raw coroutine; the coroutine will be pushed into the ready queue.
    ///
    /// Multiple threads may submit coroutines to the scheduler concurrently,
    /// but only one thread is allowed to execute the scheduling.
    pub fn submit_raw_co(&self, mut co: SchedulableCoroutine<'s>) -> std::io::Result<String> {
        for listener in self.listeners.clone() {
            co.add_raw_listener(listener);
        }
        let co_name = co.name().to_string();
        self.ready.push(co);
        Ok(co_name)
    }

    /// Resume a coroutine from the syscall table to the ready queue;
    /// this is generally only required by framework-level crates.
    ///
    /// If the coroutine can't be found, nothing happens.
    ///
    /// # Panics
    /// if changing the syscall state fails.
    pub fn try_resume(&self, co_name: &'s str) {
        if let Some((_, co)) = self.syscall.remove(&co_name) {
            match co.state() {
                CoroutineState::Syscall(val, syscall, SyscallState::Suspend(_)) => {
                    co.syscall(val, syscall, SyscallState::Callback)
                        .expect("change syscall state failed");
                }
                _ => unreachable!("try_resume: unexpected CoroutineState"),
            }
            self.ready.push(co);
        }
    }

    /// Schedule the coroutines.
    ///
    /// Multiple threads may submit coroutines to the scheduler concurrently,
    /// but only one thread is allowed to execute the scheduling.
    ///
    /// # Errors
    /// see `try_timeout_schedule`.
    pub fn try_schedule(&mut self) -> std::io::Result<HashMap<&str, Result<Option<usize>, &str>>> {
        self.try_timeout_schedule(u64::MAX)
            .map(|(_, results)| results)
    }

    /// Try scheduling the coroutines for up to `dur`.
    ///
    /// Multiple threads may submit coroutines to the scheduler concurrently,
    /// but only one thread is allowed to execute the scheduling.
    ///
    /// # Errors
    /// see `try_timeout_schedule`.
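    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming an otherwise idle scheduler: with nothing to run,
    /// the call returns the unused part of the time budget in nanoseconds together
    /// with an empty result map.
    /// ```rust,ignore
    /// let mut scheduler = Scheduler::default();
    /// let (left_ns, results) = scheduler
    ///     .try_timed_schedule(Duration::from_millis(10))
    ///     .expect("schedule failed");
    /// assert!(results.is_empty());
    /// ```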
    pub fn try_timed_schedule(
        &mut self,
        dur: Duration,
    ) -> std::io::Result<(u64, HashMap<&str, Result<Option<usize>, &str>>)> {
        self.try_timeout_schedule(get_timeout_time(dur))
    }

    /// Attempt to schedule the coroutines before the `timeout_time` timestamp.
    ///
    /// Multiple threads may submit coroutines to the scheduler concurrently,
    /// but only one thread is allowed to execute the scheduling.
    ///
    /// Returns the remaining time in nanoseconds.
    ///
    /// # Errors
    /// if changing a coroutine to the ready state fails.
    pub fn try_timeout_schedule(
        &mut self,
        timeout_time: u64,
    ) -> std::io::Result<(u64, HashMap<&str, Result<Option<usize>, &str>>)> {
        Self::init_current(self);
        let r = self.do_schedule(timeout_time);
        Self::clean_current();
        r
    }

    fn do_schedule(
        &mut self,
        timeout_time: u64,
    ) -> std::io::Result<(u64, HashMap<&str, Result<Option<usize>, &str>>)> {
        let mut results = HashMap::new();
        loop {
            let left_time = timeout_time.saturating_sub(now());
            if 0 == left_time {
                return Ok((0, results));
            }
            self.check_ready()?;
            // schedule coroutines
            if let Some(mut coroutine) = self.ready.pop() {
                let co_name = coroutine.name().to_string().leak();
                if CANCEL_COROUTINES.contains(co_name) {
                    _ = CANCEL_COROUTINES.remove(co_name);
                    warn!("Cancel coroutine:{} successfully!", co_name);
                    continue;
                }
                cfg_if::cfg_if! {
                    if #[cfg(windows)] {
                        let current_thread = unsafe {
                            windows_sys::Win32::System::Threading::GetCurrentThread()
                        } as usize;
                    } else {
                        let current_thread = nix::sys::pthread::pthread_self();
                    }
                }
                _ = RUNNING_COROUTINES.insert(co_name, current_thread);
                match coroutine.resume().inspect(|_| {
                    _ = RUNNING_COROUTINES.remove(co_name);
                })? {
                    CoroutineState::Syscall((), _, state) => {
                        // Suspend the coroutine into the syscall table.
                        // If it is already present, the current syscall has a parent
                        // syscall above it, so the insert result is simply ignored.
                        _ = self.syscall.insert(co_name, coroutine);
                        if let SyscallState::Suspend(timestamp) = state {
                            self.syscall_suspend
                                .push(SyscallSuspendItem { timestamp, co_name });
                        }
                    }
                    CoroutineState::Suspend((), timestamp) => {
                        if timestamp > now() {
                            // Suspend the coroutine into the time wheel.
                            self.suspend.push(SuspendItem {
                                timestamp,
                                coroutine,
                            });
                        } else {
                            // Push it to the tail of the ready queue.
                            self.ready.push(coroutine);
                        }
                    }
                    CoroutineState::Cancelled => {}
                    CoroutineState::Complete(result) => {
                        assert!(
                            results.insert(co_name, Ok(result)).is_none(),
                            "result not consumed"
                        );
                    }
                    CoroutineState::Error(message) => {
                        assert!(
                            results.insert(co_name, Err(message)).is_none(),
                            "result not consumed"
                        );
                    }
                    _ => {
                        return Err(Error::other(
                            "try_timeout_schedule should never reach here",
                        ));
                    }
                }
                continue;
            }
            return Ok((left_time, results));
        }
    }

    fn check_ready(&mut self) -> std::io::Result<()> {
        // Check if the elements in the suspend queue are ready
        while let Some(item) = self.suspend.peek() {
            if now() < item.timestamp {
                break;
            }
            if let Some(item) = self.suspend.pop() {
                item.coroutine.ready()?;
                self.ready.push(item.coroutine);
            }
        }
        // Check if the elements in the syscall suspend queue are ready
        while let Some(item) = self.syscall_suspend.peek() {
            if now() < item.timestamp {
                break;
            }
            if let Some(item) = self.syscall_suspend.pop() {
                if let Some((_, co)) = self.syscall.remove(item.co_name) {
                    match co.state() {
                        CoroutineState::Syscall(val, syscall, SyscallState::Suspend(_)) => {
                            co.syscall(val, syscall, SyscallState::Timeout)?;
                            self.ready.push(co);
                        }
                        _ => unreachable!("check_ready should never reach here"),
                    }
                }
            }
        }
        Ok(())
    }

    /// Cancel the coroutine by name.
    ///
    /// The cancellation takes effect before the coroutine is next resumed.
    pub fn try_cancel_coroutine(co_name: &str) {
        _ = CANCEL_COROUTINES.insert(Box::leak(Box::from(co_name)));
    }

    /// Get the thread that is currently scheduling the coroutine, if any.
    #[cfg(unix)]
    pub fn get_scheduling_thread(co_name: &str) -> Option<Pthread> {
        let co_name: &str = Box::leak(Box::from(co_name));
        RUNNING_COROUTINES.get(co_name).map(|r| *r)
    }

    /// Get the thread that is currently scheduling the coroutine, if any.
    #[cfg(windows)]
    pub fn get_scheduling_thread(co_name: &str) -> Option<windows_sys::Win32::Foundation::HANDLE> {
        let co_name: &str = Box::leak(Box::from(co_name));
        RUNNING_COROUTINES
            .get(co_name)
            .map(|r| *r as windows_sys::Win32::Foundation::HANDLE)
    }
}

#[cfg(test)]
mod tests {
    use crate::scheduler::SyscallSuspendItem;
    use std::collections::BinaryHeap;

    #[test]
    fn test_small_heap() {
        let mut heap = BinaryHeap::default();
        for timestamp in (0..10).rev() {
            heap.push(SyscallSuspendItem {
                timestamp,
                co_name: "test",
            });
        }
        for timestamp in 0..10 {
            assert_eq!(timestamp, heap.pop().unwrap().timestamp);
        }
    }
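
    /// A minimal end-to-end sketch of the submit/schedule/cancel flow. It assumes
    /// that no other scheduler in the test binary steals the submitted coroutines
    /// before this thread schedules them.
    #[test]
    fn test_submit_schedule_and_cancel() -> std::io::Result<()> {
        use crate::scheduler::Scheduler;
        let mut scheduler = Scheduler::default();
        // A completed coroutine reports its result under the returned coroutine name.
        let finished = scheduler.submit_co(|_, ()| Some(1), None, None)?;
        let results = scheduler.try_schedule()?;
        assert_eq!(results.get(finished.as_str()), Some(&Ok(Some(1))));
        // A cancelled coroutine is dropped from the ready queue and produces no result.
        let cancelled = scheduler.submit_co(|_, ()| Some(2), None, None)?;
        Scheduler::try_cancel_coroutine(&cancelled);
        assert!(!scheduler.try_schedule()?.contains_key(cancelled.as_str()));
        Ok(())
    }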
}