stuck/coroutine/
mod.rs

1//! Cooperative coroutines in task.
2
3mod context;
4mod page_size;
5pub(crate) mod stack;
6mod suspension;
7
8use std::cell::{Cell, UnsafeCell};
9use std::panic::{self, AssertUnwindSafe};
10use std::ptr;
11
12use self::context::{Context, Entry};
13use self::stack::StackSize;
14pub use self::suspension::{suspension, JoinHandle, Resumption, Suspension};
15use crate::task;
16
17thread_local! {
18    static COROUTINE: Cell<Option<ptr::NonNull<Coroutine>>> = const {  Cell::new(None) };
19    static THREAD_CONTEXT: UnsafeCell<Context> = UnsafeCell::new(Context::empty());
20}
21
22pub(crate) fn try_current() -> Option<ptr::NonNull<Coroutine>> {
23    COROUTINE.with(|p| p.get())
24}
25
26pub(crate) fn current() -> ptr::NonNull<Coroutine> {
27    COROUTINE.with(|p| p.get()).expect("no running coroutine")
28}
29
30struct Scope {
31    co: ptr::NonNull<Coroutine>,
32}
33
34impl Scope {
35    fn enter(co: &Coroutine) -> Scope {
36        COROUTINE.with(|cell| {
37            assert!(cell.get().is_none(), "running coroutine not exited");
38            cell.set(Some(ptr::NonNull::from(co)));
39        });
40        Scope { co: ptr::NonNull::from(co) }
41    }
42}
43
44impl Drop for Scope {
45    fn drop(&mut self) {
46        COROUTINE.with(|cell| {
47            let co = cell.replace(None).expect("no running coroutine");
48            assert!(co == self.co, "running coroutine changed");
49        })
50    }
51}
52
53struct ThisThread;
54
55impl ThisThread {
56    fn context<'a>() -> &'a Context {
57        THREAD_CONTEXT.with(|c| unsafe { &*c.get() })
58    }
59
60    fn context_mut<'a>() -> &'a mut Context {
61        THREAD_CONTEXT.with(|c| unsafe { &mut *c.get() })
62    }
63
64    fn resume(context: &Context) {
65        context.switch(Self::context_mut()).unwrap();
66    }
67
68    fn suspend(context: &mut Context) {
69        Self::context().switch(context).unwrap();
70    }
71
72    fn restore() {
73        Self::context().resume().unwrap();
74    }
75}
76
/// Lifecycle state of a coroutine.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub(crate) enum Status {
    Running,
    Aborting,
    Cancelling,
    Completed,
}

impl Status {
    /// Returns the status that results from an abort request.
    ///
    /// `Running` and `Aborting` become `Aborting`. `Cancelling` and `Completed` are returned
    /// unchanged.
    pub fn into_abort(self) -> Self {
        match self {
            Self::Running | Self::Aborting => Self::Aborting,
            // Spelled out (no `_` arm) so that adding a variant forces a decision here.
            Self::Cancelling | Self::Completed => self,
        }
    }
}
93
94pub(crate) struct Coroutine {
95    pub status: Status,
96    context: Option<Box<Context>>,
97    f: Option<Box<dyn FnOnce()>>,
98}
99
100unsafe impl Sync for Coroutine {}
101
102impl Coroutine {
103    pub fn new(f: Box<dyn FnOnce()>, stack_size: StackSize) -> Box<Coroutine> {
104        let mut co = Box::new(Coroutine { f: Option::Some(f), status: Status::Running, context: None });
105        let ptr = co.as_mut() as *mut Coroutine as usize;
106        #[cfg(target_pointer_width = "64")]
107        let (low, high) = (ptr as u32, (ptr >> 32) as u32);
108        #[cfg(target_pointer_width = "32")]
109        let (low, high) = (ptr as u32, 0);
110        let entry = Entry { f: Self::main, arg1: low, arg2: high, stack_size };
111        co.context = Some(Context::new(&entry, None));
112        co
113    }
114
115    extern "C" fn main(low: u32, _high: u32) {
116        #[cfg(target_pointer_width = "64")]
117        let ptr = ((_high as usize) << 32) | low as usize;
118        #[cfg(target_pointer_width = "32")]
119        let ptr = low as usize;
120        let co = unsafe { &mut *(ptr as *const Coroutine as *mut Coroutine) };
121        co.run();
122        co.status = Status::Completed;
123        ThisThread::restore();
124    }
125
126    fn run(&mut self) {
127        let f = self.f.take().expect("no entry function");
128        f();
129    }
130
131    /// Resumes coroutine.
132    ///
133    /// Returns whether this coroutine should be resumed again.
134    pub fn resume(&mut self) -> Status {
135        let _scope = Scope::enter(self);
136        ThisThread::resume(self.context.as_ref().unwrap());
137        self.status
138    }
139
140    pub fn suspend(&mut self) {
141        ThisThread::suspend(self.context.as_mut().unwrap());
142    }
143
144    pub fn is_cancelling(&self) -> bool {
145        self.status == Status::Cancelling
146    }
147}
148
149/// Spawns a cooperative task and returns a [JoinHandle] for it.
150pub fn spawn<F, T>(f: F) -> JoinHandle<T>
151where
152    F: FnOnce() -> T,
153    F: 'static,
154    T: 'static,
155{
156    let mut task = task::current();
157    let (suspension, resumption) = suspension();
158    let handle = JoinHandle::new(suspension);
159    let task = unsafe { task.as_mut() };
160    task.spawn(
161        move || {
162            let result = panic::catch_unwind(AssertUnwindSafe(f));
163            resumption.set_result(result);
164        },
165        StackSize::default(),
166    );
167    handle
168}
169
170/// Yields coroutine for next scheduling cycle.
171pub fn yield_now() {
172    let t = unsafe { task::current().as_mut() };
173    let co = current();
174    t.yield_coroutine(co);
175}
176
#[cfg(test)]
mod tests {
    use std::cell::Cell;
    use std::rc::Rc;

    use pretty_assertions::assert_eq;

    use crate::{coroutine, task};

    /// A coroutine spawned inside a task has run by the time the spawner's `yield_now`
    /// returns, so its write is visible afterwards.
    #[crate::test(crate = "crate")]
    fn yield_now() {
        let handle = task::spawn(|| {
            let observed = Rc::new(Cell::new(0));
            let writer = Rc::clone(&observed);
            coroutine::spawn(move || writer.set(5));
            coroutine::yield_now();
            observed.get()
        });
        assert_eq!(5, handle.join().unwrap());
    }
}