base_coroutine/
coroutine.rs

1use crate::context::{Context, Transfer};
2use crate::id::IdGenerator;
3#[cfg(unix)]
4use crate::monitor::Monitor;
5use crate::scheduler::Scheduler;
6use crate::stack::ProtectedFixedSizeStack;
7use crate::stack::StackError::{ExceedsMaximumSize, IoError};
8use std::cell::RefCell;
9use std::marker::PhantomData;
10use std::mem::{ManuallyDrop, MaybeUninit};
11use std::os::raw::c_void;
12
/// Lifecycle state of a coroutine.
///
/// The enum is `#[repr(C)]`, so the declaration order below fixes each variant's discriminant.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Status {
    /// The coroutine has been created but not yet resumed.
    Created,
    /// The coroutine is waiting to run.
    Ready,
    /// The coroutine is running.
    Running,
    /// The coroutine has been suspended.
    Suspend,
    /// The coroutine is executing a system call.
    SystemCall,
    /// The coroutine's stack is being grown or shrunk.
    CopyStack,
    /// The user function has completed, but the coroutine has not exited yet.
    Finished,
    /// The coroutine has exited.
    Exited,
}
33
34#[repr(transparent)]
35pub struct Yielder<'a, Param, Yield, Return> {
36    sp: &'a Transfer,
37    marker: PhantomData<fn(Yield) -> CoroutineResult<Param, Return>>,
38}
39
thread_local! {
    /// Per-thread delay, in nanoseconds, recorded by `Yielder::delay`/`Yielder::delay_ns`.
    /// Starts at `0` and is reset to `0` by `Yielder::clean_delay`.
    static DELAY_TIME: Box<RefCell<u64>> = Box::default();
}
43
44impl<'a, Param, Yield, Return> Yielder<'a, Param, Yield, Return> {
45    /// Suspends the execution of a currently running coroutine.
46    ///
47    /// This function will switch control back to the original caller of
48    /// [`Coroutine::resume`]. This function will then return once the
49    /// [`Coroutine::resume`] function is called again.
50    pub extern "C" fn suspend(&self, val: Yield) -> Param {
51        OpenCoroutine::<Param, Yield, Return>::clean_current();
52        let yielder = OpenCoroutine::<Param, Yield, Return>::yielder();
53        OpenCoroutine::<Param, Yield, Return>::clean_yielder();
54        unsafe {
55            let mut coroutine_result = CoroutineResult::<Yield, Return>::Yield(val);
56            //see Scheduler.do_schedule
57            let transfer = self
58                .sp
59                .context
60                .resume(&mut coroutine_result as *mut _ as usize);
61            OpenCoroutine::init_yielder(&*yielder);
62            let backed = transfer.data as *mut c_void as *mut _
63                as *mut OpenCoroutine<'_, Param, Yield, Return>;
64            std::ptr::read_unaligned(&(*backed).param)
65        }
66    }
67
68    pub(crate) extern "C" fn syscall(&self) {
69        OpenCoroutine::<Param, Yield, Return>::clean_current();
70        let yielder = OpenCoroutine::<Param, Yield, Return>::yielder();
71        OpenCoroutine::<Param, Yield, Return>::clean_yielder();
72        unsafe {
73            let mut coroutine_result = CoroutineResult::<Yield, Return>::SystemCall;
74            //see Scheduler.do_schedule
75            self.sp
76                .context
77                .resume(&mut coroutine_result as *mut _ as usize);
78            OpenCoroutine::init_yielder(&*yielder);
79        }
80    }
81
82    pub extern "C" fn delay(&self, val: Yield, ms_time: u64) -> Param {
83        self.delay_ns(
84            val,
85            match ms_time.checked_mul(1_000_000) {
86                Some(v) => v,
87                None => u64::MAX,
88            },
89        )
90    }
91
92    pub extern "C" fn delay_ns(&self, val: Yield, ns_time: u64) -> Param {
93        Yielder::<Param, Yield, Return>::init_delay_time(ns_time);
94        self.suspend(val)
95    }
96
97    fn init_delay_time(time: u64) {
98        DELAY_TIME.with(|boxed| {
99            *boxed.borrow_mut() = time;
100        });
101    }
102
103    pub(crate) fn delay_time() -> u64 {
104        DELAY_TIME.with(|boxed| *boxed.borrow_mut())
105    }
106
107    pub(crate) fn clean_delay() {
108        DELAY_TIME.with(|boxed| *boxed.borrow_mut() = 0)
109    }
110}
111
/// Outcome reported to the caller after resuming a coroutine.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CoroutineResult<Yield, Return> {
    /// The coroutine suspended itself through its `Yielder`, handing out this value.
    Yield(Yield),

    /// The coroutine's main function returned this value.
    Return(Return),

    /// The coroutine switched out to perform a system call (see `Yielder::syscall`).
    SystemCall,
}
123
124impl<Yield, Return> CoroutineResult<Yield, Return> {
125    /// Returns the `Yield` value as an `Option<Yield>`.
126    pub fn as_yield(self) -> Option<Yield> {
127        match self {
128            CoroutineResult::Yield(val) => Some(val),
129            CoroutineResult::Return(_) => None,
130            CoroutineResult::SystemCall => None,
131        }
132    }
133
134    /// Returns the `Return` value as an `Option<Return>`.
135    pub fn as_return(self) -> Option<Return> {
136        match self {
137            CoroutineResult::Yield(_) => None,
138            CoroutineResult::Return(val) => Some(val),
139            CoroutineResult::SystemCall => None,
140        }
141    }
142}
143
144pub type UserFunc<'a, Param, Yield, Return> =
145    extern "C" fn(&'a Yielder<Param, Yield, Return>, Param) -> Return;
146
147/// 子协程
148pub type Coroutine<Input, Return> = OpenCoroutine<'static, Input, (), Return>;
149
thread_local! {
    /// Address of the `OpenCoroutine` currently being resumed on this thread, or null.
    static COROUTINE: Box<RefCell<*mut c_void>> =
        Box::new(RefCell::new(std::ptr::null_mut::<c_void>()));
    /// Address of the running coroutine's `Yielder` on this thread, or null.
    static YIELDER: Box<RefCell<*const c_void>> =
        Box::new(RefCell::new(std::ptr::null::<c_void>()));
}
154
155#[repr(C)]
156pub struct OpenCoroutine<'a, Param, Yield, Return> {
157    id: usize,
158    sp: Transfer,
159    stack: ProtectedFixedSizeStack,
160    pub(crate) status: Status,
161    //用户函数
162    proc: UserFunc<'a, Param, Yield, Return>,
163    marker: PhantomData<&'a extern "C" fn(Param) -> CoroutineResult<Yield, Return>>,
164    //调用用户函数的参数
165    param: Param,
166    result: MaybeUninit<ManuallyDrop<Return>>,
167    scheduler: Option<*mut Scheduler>,
168}
169
170impl<'a, Param, Yield, Return> OpenCoroutine<'a, Param, Yield, Return> {
171    extern "C" fn child_context_func(t: Transfer) {
172        let coroutine = unsafe {
173            &mut *(t.data as *mut c_void as *mut _ as *mut OpenCoroutine<'_, Param, Yield, Return>)
174        };
175        let yielder = Yielder {
176            sp: &t,
177            marker: Default::default(),
178        };
179        OpenCoroutine::init_yielder(&yielder);
180        unsafe {
181            coroutine.status = Status::Running;
182            let proc = coroutine.proc;
183            let param = std::ptr::read_unaligned(&coroutine.param);
184            let result = proc(&yielder, param);
185            coroutine.status = Status::Finished;
186            OpenCoroutine::<Param, Yield, Return>::clean_current();
187            OpenCoroutine::<Param, Yield, Return>::clean_yielder();
188            #[cfg(unix)]
189            {
190                //还没执行到10ms就返回了,此时需要清理signal
191                //否则下一个协程执行不到10ms就被抢占调度了
192                Monitor::clean_task(Monitor::signal_time());
193                Monitor::clean_signal_time();
194            }
195            if let Some(scheduler) = coroutine.scheduler {
196                coroutine.result = MaybeUninit::new(ManuallyDrop::new(result));
197                //执行下一个子协程
198                (*scheduler).do_schedule();
199            } else {
200                let mut coroutine_result = CoroutineResult::<Yield, Return>::Return(result);
201                t.context.resume(&mut coroutine_result as *mut _ as usize);
202                unreachable!("should not execute to here !")
203            }
204        }
205    }
206
207    pub fn new(
208        proc: UserFunc<'a, Param, Yield, Return>,
209        param: Param,
210        size: usize,
211    ) -> std::io::Result<Self> {
212        let stack = ProtectedFixedSizeStack::new(size).map_err(|e| match e {
213            ExceedsMaximumSize(size) => std::io::Error::new(
214                std::io::ErrorKind::Other,
215                "Requested more than max size of ".to_owned()
216                    + &size.to_string()
217                    + " bytes for a stack",
218            ),
219            IoError(e) => e,
220        })?;
221        Ok(OpenCoroutine {
222            id: IdGenerator::next_coroutine_id(),
223            sp: Transfer::new(
224                unsafe {
225                    Context::new(
226                        &stack,
227                        OpenCoroutine::<Param, Yield, Return>::child_context_func,
228                    )
229                },
230                0,
231            ),
232            stack,
233            status: Status::Created,
234            proc,
235            marker: Default::default(),
236            param,
237            result: MaybeUninit::uninit(),
238            scheduler: None,
239        })
240    }
241
242    pub fn resume_with(&mut self, val: Param) -> CoroutineResult<Yield, Return> {
243        self.param = val;
244        self.resume()
245    }
246
247    pub fn resume(&mut self) -> CoroutineResult<Yield, Return> {
248        self.status = Status::Ready;
249        self.sp.data = self as *mut _ as usize;
250        unsafe {
251            OpenCoroutine::init_current(self);
252            let transfer = self.sp.context.resume(self.sp.data);
253            //更新sp
254            self.sp.context = transfer.context;
255            std::ptr::read_unaligned(
256                transfer.data as *mut c_void as *mut _ as *mut CoroutineResult<Yield, Return>,
257            )
258        }
259    }
260
261    pub fn get_id(&self) -> usize {
262        self.id
263    }
264
265    pub fn get_status(&self) -> Status {
266        self.status
267    }
268
269    pub fn get_result(&self) -> Option<Return> {
270        if self.get_status() == Status::Finished {
271            unsafe {
272                let mut m = self.result.assume_init_read();
273                Some(ManuallyDrop::take(&mut m))
274            }
275        } else {
276            None
277        }
278    }
279
280    pub fn get_scheduler(&self) -> Option<*mut Scheduler> {
281        self.scheduler
282    }
283
284    pub(crate) fn set_scheduler(&mut self, scheduler: &mut Scheduler) {
285        self.scheduler = Some(scheduler);
286    }
287
288    fn init_yielder(yielder: &Yielder<Param, Yield, Return>) {
289        YIELDER.with(|boxed| {
290            *boxed.borrow_mut() = yielder as *const _ as *const c_void;
291        });
292    }
293
294    pub fn yielder<'y>() -> *const Yielder<'y, Param, Yield, Return> {
295        YIELDER.with(|boxed| unsafe { std::mem::transmute(*boxed.borrow_mut()) })
296    }
297
298    fn clean_yielder() {
299        YIELDER.with(|boxed| *boxed.borrow_mut() = std::ptr::null())
300    }
301
302    fn init_current(coroutine: &mut OpenCoroutine<'a, Param, Yield, Return>) {
303        COROUTINE.with(|boxed| {
304            *boxed.borrow_mut() = coroutine as *mut _ as *mut c_void;
305        })
306    }
307
308    pub fn current<'c>() -> Option<&'a mut OpenCoroutine<'c, Param, Yield, Return>> {
309        COROUTINE.with(|boxed| {
310            let ptr = *boxed.borrow_mut();
311            if ptr.is_null() {
312                None
313            } else {
314                Some(unsafe { &mut *(ptr as *mut OpenCoroutine<Param, Yield, Return>) })
315            }
316        })
317    }
318
319    fn clean_current() {
320        COROUTINE.with(|boxed| *boxed.borrow_mut() = std::ptr::null_mut())
321    }
322}
323
324impl<'a, Param, Yield, Return> Drop for OpenCoroutine<'a, Param, Yield, Return> {
325    fn drop(&mut self) {
326        self.status = Status::Exited;
327    }
328}
329
#[cfg(test)]
mod tests {
    use crate::coroutine::{OpenCoroutine, Yielder};

    /// Stack size, in bytes, requested for every coroutine in these tests.
    const STACK_SIZE: usize = 2048;

    /// A coroutine that returns immediately reports its value through `as_return`.
    #[test]
    fn test_return() {
        extern "C" fn context_func(_yielder: &Yielder<usize, usize, usize>, input: usize) -> usize {
            assert_eq!(0, input);
            1
        }
        let mut coroutine =
            OpenCoroutine::new(context_func, 0, STACK_SIZE).expect("create coroutine failed !");
        let outcome = coroutine.resume_with(0);
        assert_eq!(Some(1), outcome.as_return());
    }

    /// The first suspension yields its value to the resumer.
    #[test]
    fn test_yield_once() {
        extern "C" fn context_func(yielder: &Yielder<usize, usize, usize>, input: usize) -> usize {
            assert_eq!(1, input);
            assert_eq!(3, yielder.suspend(2));
            6
        }
        let mut coroutine =
            OpenCoroutine::new(context_func, 1, STACK_SIZE).expect("create coroutine failed !");
        let first = coroutine.resume_with(1);
        assert_eq!(Some(2), first.as_yield());
    }

    /// Parameters and yielded values alternate correctly across several resumes.
    #[test]
    fn test_yield() {
        extern "C" fn context_func(yielder: &Yielder<usize, usize, usize>, input: usize) -> usize {
            assert_eq!(1, input);
            assert_eq!(3, yielder.suspend(2));
            assert_eq!(5, yielder.suspend(4));
            6
        }
        let mut coroutine =
            OpenCoroutine::new(context_func, 1, STACK_SIZE).expect("create coroutine failed !");
        for (param, expected) in [(1, 2), (3, 4)] {
            assert_eq!(Some(expected), coroutine.resume_with(param).as_yield());
        }
        assert_eq!(Some(6), coroutine.resume_with(5).as_return());
    }

    /// `current` is only set while a coroutine is running.
    #[test]
    fn test_current() {
        extern "C" fn context_func(
            _yielder: &Yielder<usize, usize, usize>,
            _input: usize,
        ) -> usize {
            assert!(OpenCoroutine::<usize, usize, usize>::current().is_some());
            1
        }
        assert!(OpenCoroutine::<usize, usize, usize>::current().is_none());
        let mut coroutine =
            OpenCoroutine::new(context_func, 0, STACK_SIZE).expect("create coroutine failed !");
        assert!(coroutine.resume_with(0).as_return().is_some());
    }
}