base_coroutine/
scheduler.rs

1use crate::coroutine::{Coroutine, CoroutineResult, OpenCoroutine, Status, UserFunc, Yielder};
2use crate::id::IdGenerator;
3#[cfg(unix)]
4use crate::monitor::Monitor;
5use crate::stack::Stack;
6use crate::work_steal::{get_queue, WorkStealQueue};
7use object_collection::{ObjectList, ObjectMap};
8use once_cell::sync::Lazy;
9use std::cell::RefCell;
10use std::os::raw::c_void;
11use std::time::Duration;
12use timer_utils::TimerList;
13
thread_local! {
    /// Type-erased pointer to the main coroutine's `Yielder` on this thread.
    /// Starts out null; written by `Scheduler::init_yielder`, reset by `Scheduler::clean_yielder`.
    static YIELDER: Box<RefCell<*const c_void>> =
        Box::new(RefCell::new(std::ptr::null::<c_void>()));
}

thread_local! {
    /// Deadline of the current scheduling pass on this thread (in `timer_utils` time units).
    /// Starts out 0; written by `Scheduler::init_timeout_time`, reset by `Scheduler::clean_time`.
    static TIMEOUT_TIME: Box<RefCell<u64>> = Box::new(RefCell::new(0u64));
}
18
19/// 主协程
20type MainCoroutine<'a> = OpenCoroutine<'a, *mut Scheduler, (), ()>;
21
22static mut SYSTEM_CALL_TABLE: Lazy<ObjectMap<usize>> = Lazy::new(ObjectMap::new);
23
24static mut SUSPEND_TABLE: Lazy<TimerList> = Lazy::new(TimerList::new);
25
26#[repr(C)]
27#[derive(Debug)]
28pub struct Scheduler {
29    id: usize,
30    ready: &'static mut WorkStealQueue,
31    //not support for now
32    copy_stack: ObjectList,
33}
34
35impl Scheduler {
36    pub fn new() -> Self {
37        Scheduler {
38            id: IdGenerator::next_scheduler_id(),
39            ready: get_queue(),
40            copy_stack: ObjectList::new(),
41        }
42    }
43
44    pub fn current<'a>() -> Option<&'a mut Scheduler> {
45        if let Some(co) = Coroutine::<&'static mut c_void, &'static mut c_void>::current() {
46            if let Some(ptr) = co.get_scheduler() {
47                return Some(unsafe { &mut *ptr });
48            }
49        }
50        None
51    }
52
53    fn init_yielder(yielder: &Yielder<*mut Scheduler, (), ()>) {
54        YIELDER.with(|boxed| {
55            *boxed.borrow_mut() = yielder as *const _ as *const c_void;
56        });
57    }
58
59    fn yielder<'a>() -> *const Yielder<'a, *mut Scheduler, (), ()> {
60        YIELDER.with(|boxed| unsafe { std::mem::transmute(*boxed.borrow_mut()) })
61    }
62
63    fn clean_yielder() {
64        YIELDER.with(|boxed| *boxed.borrow_mut() = std::ptr::null())
65    }
66
67    fn init_timeout_time(timeout_time: u64) {
68        TIMEOUT_TIME.with(|boxed| {
69            *boxed.borrow_mut() = timeout_time;
70        });
71    }
72
73    fn timeout_time() -> u64 {
74        TIMEOUT_TIME.with(|boxed| *boxed.borrow_mut())
75    }
76
77    fn clean_time() {
78        TIMEOUT_TIME.with(|boxed| *boxed.borrow_mut() = 0)
79    }
80
81    pub fn submit(
82        &mut self,
83        f: UserFunc<&'static mut c_void, (), &'static mut c_void>,
84        val: &'static mut c_void,
85        size: usize,
86    ) -> std::io::Result<&'static Coroutine<&'static mut c_void, &'static mut c_void>> {
87        let mut coroutine = Coroutine::new(f, val, size)?;
88        coroutine.status = Status::Ready;
89        coroutine.set_scheduler(self);
90        let ptr = Box::leak(Box::new(coroutine));
91        self.ready.push_back_raw(ptr as *mut _ as *mut c_void)?;
92        Ok(ptr)
93    }
94
95    pub fn timed_schedule(&mut self, timeout: Duration) -> std::io::Result<()> {
96        let timeout_time = timer_utils::get_timeout_time(timeout);
97        while !self.ready.is_empty()
98            || unsafe { !SUSPEND_TABLE.is_empty() || !SYSTEM_CALL_TABLE.is_empty() }
99            || !self.copy_stack.is_empty()
100        {
101            if timeout_time <= timer_utils::now() {
102                break;
103            }
104            self.try_timeout_schedule(timeout_time)?;
105        }
106        Ok(())
107    }
108
109    pub fn try_schedule(&mut self) -> std::io::Result<()> {
110        self.try_timeout_schedule(Duration::MAX.as_secs())
111    }
112
113    pub fn try_timed_schedule(&mut self, time: Duration) -> std::io::Result<()> {
114        self.try_timeout_schedule(timer_utils::get_timeout_time(time))
115    }
116
117    pub fn try_timeout_schedule(&mut self, timeout_time: u64) -> std::io::Result<()> {
118        Scheduler::init_timeout_time(timeout_time);
119        extern "C" fn main_context_func(
120            yielder: &Yielder<*mut Scheduler, (), ()>,
121            scheduler: *mut Scheduler,
122        ) {
123            Scheduler::init_yielder(yielder);
124            unsafe { (*scheduler).do_schedule() };
125            unreachable!("should not execute to here !")
126        }
127        let mut main = MainCoroutine::new(main_context_func, self, Stack::default_size())?;
128        assert_eq!(main.resume(), CoroutineResult::Yield(()));
129        Scheduler::clean_time();
130        Ok(())
131    }
132
133    fn back_to_main() {
134        //跳回主线程
135        let yielder = Scheduler::yielder();
136        Scheduler::clean_yielder();
137        if !yielder.is_null() {
138            unsafe {
139                (*yielder).suspend(());
140            }
141        }
142    }
143
144    pub(crate) fn do_schedule(&mut self) {
145        if Scheduler::timeout_time() <= timer_utils::now() {
146            Scheduler::back_to_main()
147        }
148        let _ = self.check_ready();
149        match self.ready.pop_front_raw() {
150            Some(pointer) => {
151                let coroutine = unsafe {
152                    &mut *(pointer as *mut Coroutine<&'static mut c_void, &'static mut c_void>)
153                };
154                let _start = timer_utils::get_timeout_time(Duration::from_millis(10));
155                #[cfg(unix)]
156                {
157                    Monitor::init_signal_time(_start);
158                    Monitor::add_task(_start);
159                }
160                //see OpenCoroutine::child_context_func
161                match coroutine.resume() {
162                    CoroutineResult::Yield(()) => {
163                        let delay_time =
164                            Yielder::<&'static mut c_void, (), &'static mut c_void>::delay_time();
165                        if delay_time > 0 {
166                            //挂起协程到时间轮
167                            coroutine.status = Status::Suspend;
168                            unsafe {
169                                SUSPEND_TABLE.insert_raw(
170                                    timer_utils::add_timeout_time(delay_time),
171                                    coroutine as *mut _ as *mut c_void,
172                                );
173                            }
174                            Yielder::<&'static mut c_void, (), &'static mut c_void>::clean_delay();
175                        } else {
176                            //放入就绪队列尾部
177                            let _ = self.ready.push_back_raw(coroutine as *mut _ as *mut c_void);
178                        }
179                    }
180                    CoroutineResult::Return(_) => unreachable!("never have a result"),
181                    CoroutineResult::SystemCall => {
182                        coroutine.status = Status::SystemCall;
183                        unsafe { SYSTEM_CALL_TABLE.insert(coroutine.get_id(), coroutine) };
184                    }
185                };
186                #[cfg(unix)]
187                {
188                    //还没执行到10ms就主动yield了,此时需要清理signal
189                    //否则下一个协程执行不到10ms就被抢占调度了
190                    Monitor::clean_task(_start);
191                    Monitor::clean_signal_time();
192                }
193                self.do_schedule();
194            }
195            None => Scheduler::back_to_main(),
196        }
197    }
198
199    fn check_ready(&mut self) -> std::io::Result<()> {
200        unsafe {
201            for _ in 0..SUSPEND_TABLE.len() {
202                if let Some(entry) = SUSPEND_TABLE.front() {
203                    let exec_time = entry.get_time();
204                    if timer_utils::now() < exec_time {
205                        break;
206                    }
207                    //移动至"就绪"队列
208                    if let Some(mut entry) = SUSPEND_TABLE.pop_front() {
209                        for _ in 0..entry.len() {
210                            if let Some(pointer) = entry.pop_front_raw() {
211                                let coroutine = &mut *(pointer
212                                    as *mut Coroutine<&'static mut c_void, &'static mut c_void>);
213                                coroutine.status = Status::Ready;
214                                //把到时间的协程加入就绪队列
215                                self.ready
216                                    .push_back_raw(coroutine as *mut _ as *mut c_void)?
217                            }
218                        }
219                    }
220                }
221            }
222            Ok(())
223        }
224    }
225
226    /// 用户不应该使用此方法
227    pub fn syscall(&mut self) {
228        //挂起当前协程
229        let yielder = Coroutine::<&'static mut c_void, &'static mut c_void>::yielder();
230        if !yielder.is_null() {
231            unsafe { (*yielder).syscall() };
232        }
233    }
234
235    /// 用户不应该使用此方法
236    #[allow(clippy::missing_safety_doc)]
237    pub unsafe fn resume(&mut self, co_id: usize) -> std::io::Result<()> {
238        if let Some(co) = SYSTEM_CALL_TABLE.remove(&co_id) {
239            self.ready.push_back_raw(co)?;
240        }
241        Ok(())
242    }
243}
244
245impl Default for Scheduler {
246    fn default() -> Self {
247        Self::new()
248    }
249}
250
#[cfg(test)]
mod tests {
    use crate::coroutine::Yielder;
    use crate::scheduler::Scheduler;
    use std::os::raw::c_void;
    use std::thread;
    use std::time::Duration;

    /// Returns a `&'static mut c_void` to use as a coroutine's input and return value.
    ///
    /// The test coroutines ignore it; it only satisfies the signatures. It points at a leaked
    /// one-byte heap allocation, so the reference is dereferenceable. Fabricating a reference
    /// from an integer address (as the old `null()` helper did with address 10) is undefined
    /// behaviour even if the reference is never read.
    fn placeholder() -> &'static mut c_void {
        let byte: *mut u8 = Box::leak(Box::new(0u8));
        // SAFETY: `byte` is a live, never-freed allocation. `c_void` is a one-byte
        // `#[repr(u8)]` enum in std, and 0 is one of its discriminants.
        unsafe { &mut *byte.cast::<c_void>() }
    }

    #[test]
    fn simplest() {
        let mut scheduler = Scheduler::new();
        extern "C" fn f1(
            _yielder: &Yielder<&'static mut c_void, (), &'static mut c_void>,
            _input: &'static mut c_void,
        ) -> &'static mut c_void {
            println!("[coroutine1] launched");
            placeholder()
        }
        scheduler
            .submit(f1, placeholder(), 4096)
            .expect("submit failed !");
        extern "C" fn f2(
            _yielder: &Yielder<&'static mut c_void, (), &'static mut c_void>,
            _input: &'static mut c_void,
        ) -> &'static mut c_void {
            println!("[coroutine2] launched");
            placeholder()
        }
        scheduler
            .submit(f2, placeholder(), 4096)
            .expect("submit failed !");
        scheduler.try_schedule().expect("try_schedule failed !");
    }

    #[test]
    fn with_suspend() {
        let mut scheduler = Scheduler::new();
        extern "C" fn suspend1(
            yielder: &Yielder<&'static mut c_void, (), &'static mut c_void>,
            _input: &'static mut c_void,
        ) -> &'static mut c_void {
            println!("[coroutine1] suspend");
            yielder.suspend(());
            println!("[coroutine1] back");
            placeholder()
        }
        scheduler
            .submit(suspend1, placeholder(), 4096)
            .expect("submit failed !");
        extern "C" fn suspend2(
            yielder: &Yielder<&'static mut c_void, (), &'static mut c_void>,
            _input: &'static mut c_void,
        ) -> &'static mut c_void {
            println!("[coroutine2] suspend");
            yielder.suspend(());
            println!("[coroutine2] back");
            placeholder()
        }
        scheduler
            .submit(suspend2, placeholder(), 4096)
            .expect("submit failed !");
        scheduler.try_schedule().expect("try_schedule failed !");
    }

    /// Coroutine body that yields with a delay of 100 and then finishes.
    extern "C" fn delay(
        yielder: &Yielder<&'static mut c_void, (), &'static mut c_void>,
        _input: &'static mut c_void,
    ) -> &'static mut c_void {
        println!("[coroutine] delay");
        yielder.delay((), 100);
        println!("[coroutine] back");
        placeholder()
    }

    #[test]
    fn with_delay() {
        let mut scheduler = Scheduler::new();
        scheduler
            .submit(delay, placeholder(), 4096)
            .expect("submit failed !");
        scheduler.try_schedule().expect("try_schedule failed !");
        thread::sleep(Duration::from_millis(100));
        scheduler.try_schedule().expect("try_schedule failed !");
    }

    #[test]
    fn timed_schedule() {
        let mut scheduler = Scheduler::new();
        scheduler
            .submit(delay, placeholder(), 4096)
            .expect("submit failed !");
        scheduler
            .timed_schedule(Duration::from_millis(200))
            .expect("timed_schedule failed !");
    }

    /// `f1` spins until `f2` clears the flag. This only terminates if the monitor preempts
    /// `f1` so that `f2` gets to run on the same scheduler thread.
    #[cfg(unix)]
    #[test]
    fn preemptive_schedule() {
        use std::sync::atomic::{AtomicBool, Ordering};
        // An atomic instead of `static mut`: no `unsafe` needed, and the cross-thread read
        // after `join` is well-defined.
        static FLAG: AtomicBool = AtomicBool::new(true);
        let handler = std::thread::spawn(|| {
            let mut scheduler = Scheduler::new();
            extern "C" fn f1(
                _yielder: &Yielder<&'static mut c_void, (), &'static mut c_void>,
                _input: &'static mut c_void,
            ) -> &'static mut c_void {
                while FLAG.load(Ordering::SeqCst) {
                    println!("loop");
                    std::thread::sleep(Duration::from_millis(10));
                }
                placeholder()
            }
            scheduler
                .submit(f1, placeholder(), 4096)
                .expect("submit failed !");
            extern "C" fn f2(
                _yielder: &Yielder<&'static mut c_void, (), &'static mut c_void>,
                _input: &'static mut c_void,
            ) -> &'static mut c_void {
                FLAG.store(false, Ordering::SeqCst);
                placeholder()
            }
            scheduler
                .submit(f2, placeholder(), 4096)
                .expect("submit failed !");
            scheduler.try_schedule().expect("try_schedule failed !");
        });
        handler.join().unwrap();
        assert!(!FLAG.load(Ordering::SeqCst));
    }
}