1mod context;
4mod page_size;
5pub(crate) mod stack;
6mod suspension;
7
8use std::cell::{Cell, UnsafeCell};
9use std::panic::{self, AssertUnwindSafe};
10use std::ptr;
11
12use self::context::{Context, Entry};
13use self::stack::StackSize;
14pub use self::suspension::{suspension, JoinHandle, Resumption, Suspension};
15use crate::task;
16
17thread_local! {
18 static COROUTINE: Cell<Option<ptr::NonNull<Coroutine>>> = const { Cell::new(None) };
19 static THREAD_CONTEXT: UnsafeCell<Context> = UnsafeCell::new(Context::empty());
20}
21
22pub(crate) fn try_current() -> Option<ptr::NonNull<Coroutine>> {
23 COROUTINE.with(|p| p.get())
24}
25
26pub(crate) fn current() -> ptr::NonNull<Coroutine> {
27 COROUTINE.with(|p| p.get()).expect("no running coroutine")
28}
29
30struct Scope {
31 co: ptr::NonNull<Coroutine>,
32}
33
34impl Scope {
35 fn enter(co: &Coroutine) -> Scope {
36 COROUTINE.with(|cell| {
37 assert!(cell.get().is_none(), "running coroutine not exited");
38 cell.set(Some(ptr::NonNull::from(co)));
39 });
40 Scope { co: ptr::NonNull::from(co) }
41 }
42}
43
44impl Drop for Scope {
45 fn drop(&mut self) {
46 COROUTINE.with(|cell| {
47 let co = cell.replace(None).expect("no running coroutine");
48 assert!(co == self.co, "running coroutine changed");
49 })
50 }
51}
52
53struct ThisThread;
54
55impl ThisThread {
56 fn context<'a>() -> &'a Context {
57 THREAD_CONTEXT.with(|c| unsafe { &*c.get() })
58 }
59
60 fn context_mut<'a>() -> &'a mut Context {
61 THREAD_CONTEXT.with(|c| unsafe { &mut *c.get() })
62 }
63
64 fn resume(context: &Context) {
65 context.switch(Self::context_mut()).unwrap();
66 }
67
68 fn suspend(context: &mut Context) {
69 Self::context().switch(context).unwrap();
70 }
71
72 fn restore() {
73 Self::context().resume().unwrap();
74 }
75}
76
/// Lifecycle state of a coroutine.
///
/// A coroutine starts as `Running` and is set to `Completed` once its entry closure returns.
/// NOTE(review): the exact meaning of `Aborting` vs `Cancelling` is defined by the scheduler
/// outside this file — confirm there.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub(crate) enum Status {
    Running,
    Aborting,
    Cancelling,
    Completed,
}

impl Status {
    /// Returns the status after an abort request.
    ///
    /// `Running` and `Aborting` become `Aborting`; `Cancelling` and `Completed` are returned
    /// unchanged.
    pub fn into_abort(self) -> Self {
        // Exhaustive on purpose: a new variant must force a decision here instead of silently
        // falling into a catch-all arm.
        match self {
            Self::Running | Self::Aborting => Self::Aborting,
            Self::Cancelling | Self::Completed => self,
        }
    }
}
93
94pub(crate) struct Coroutine {
95 pub status: Status,
96 context: Option<Box<Context>>,
97 f: Option<Box<dyn FnOnce()>>,
98}
99
100unsafe impl Sync for Coroutine {}
101
102impl Coroutine {
103 pub fn new(f: Box<dyn FnOnce()>, stack_size: StackSize) -> Box<Coroutine> {
104 let mut co = Box::new(Coroutine { f: Option::Some(f), status: Status::Running, context: None });
105 let ptr = co.as_mut() as *mut Coroutine as usize;
106 #[cfg(target_pointer_width = "64")]
107 let (low, high) = (ptr as u32, (ptr >> 32) as u32);
108 #[cfg(target_pointer_width = "32")]
109 let (low, high) = (ptr as u32, 0);
110 let entry = Entry { f: Self::main, arg1: low, arg2: high, stack_size };
111 co.context = Some(Context::new(&entry, None));
112 co
113 }
114
115 extern "C" fn main(low: u32, _high: u32) {
116 #[cfg(target_pointer_width = "64")]
117 let ptr = ((_high as usize) << 32) | low as usize;
118 #[cfg(target_pointer_width = "32")]
119 let ptr = low as usize;
120 let co = unsafe { &mut *(ptr as *const Coroutine as *mut Coroutine) };
121 co.run();
122 co.status = Status::Completed;
123 ThisThread::restore();
124 }
125
126 fn run(&mut self) {
127 let f = self.f.take().expect("no entry function");
128 f();
129 }
130
131 pub fn resume(&mut self) -> Status {
135 let _scope = Scope::enter(self);
136 ThisThread::resume(self.context.as_ref().unwrap());
137 self.status
138 }
139
140 pub fn suspend(&mut self) {
141 ThisThread::suspend(self.context.as_mut().unwrap());
142 }
143
144 pub fn is_cancelling(&self) -> bool {
145 self.status == Status::Cancelling
146 }
147}
148
149pub fn spawn<F, T>(f: F) -> JoinHandle<T>
151where
152 F: FnOnce() -> T,
153 F: 'static,
154 T: 'static,
155{
156 let mut task = task::current();
157 let (suspension, resumption) = suspension();
158 let handle = JoinHandle::new(suspension);
159 let task = unsafe { task.as_mut() };
160 task.spawn(
161 move || {
162 let result = panic::catch_unwind(AssertUnwindSafe(f));
163 resumption.set_result(result);
164 },
165 StackSize::default(),
166 );
167 handle
168}
169
170pub fn yield_now() {
172 let t = unsafe { task::current().as_mut() };
173 let co = current();
174 t.yield_coroutine(co);
175}
176
#[cfg(test)]
mod tests {
    use std::cell::Cell;
    use std::rc::Rc;

    use pretty_assertions::assert_eq;

    use crate::{coroutine, task};

    /// After `coroutine::yield_now`, a coroutine spawned just before the yield has run and
    /// its write is visible to the yielding coroutine.
    #[crate::test(crate = "crate")]
    fn yield_now() {
        let five = task::spawn(|| {
            let counter = Rc::new(Cell::new(0));
            let writer = Rc::clone(&counter);
            coroutine::spawn(move || writer.set(5));
            coroutine::yield_now();
            counter.get()
        });
        assert_eq!(5, five.join().unwrap());
    }
}