1extern crate spin;
2
3pub mod promise;
4
5use std::any::Any;
6use std::sync::Arc;
7use std::panic::{catch_unwind, AssertUnwindSafe};
8pub use promise::Promise;
9
/// Opaque handle to a backend coroutine. In this file it is only ever
/// handled through `*const CoroutineImpl` pointers obtained from the backend.
#[repr(C)]
pub(crate) struct CoroutineImpl { _dummy: usize }

/// Opaque stand-in for the untyped user-data pointer the backend stores per
/// coroutine; concrete payloads are cast to and from `*const AnyUserData`.
#[repr(C)]
pub(crate) struct AnyUserData { _dummy: usize }

/// Callback signature passed to `coroutine_async_enter`: receives the
/// coroutine handle and the user-data pointer supplied alongside it.
pub(crate) type AsyncEntry = extern "C" fn(*const CoroutineImpl, *const AnyUserData);

/// Backend version this crate was built against; `spawn_with_callback`
/// refuses to launch coroutines if `ubh_get_version()` reports anything else.
const REQUIRED_BACKEND_VERSION: i32 = 20;
23
24#[link(name = "unblock_hook", kind = "dylib")]
25extern "C" {
26 pub(crate) fn current_coroutine() -> *const CoroutineImpl;
27 fn launch_co(
28 entry: extern "C" fn (*const CoroutineImpl),
29 user_data: *const AnyUserData
30 );
31 fn extract_co_user_data(
32 co: *const CoroutineImpl
33 ) -> *const AnyUserData;
34 fn co_get_global_event_count() -> usize;
35
36 fn coroutine_yield(
37 co: *const CoroutineImpl
38 );
39
40 fn gtp_enable_work_stealing();
41 fn gtp_disable_work_stealing();
42 fn gtp_get_migration_count() -> i32;
43
44 fn ubh_get_version() -> i32;
45
46 pub(crate) fn coroutine_async_enter(
47 co: *const CoroutineImpl,
48 entry: AsyncEntry,
49 data: *const AnyUserData
50 ) -> *const AnyUserData;
51
52 pub(crate) fn coroutine_async_exit(
53 co: *const CoroutineImpl,
54 data: *const AnyUserData
55 );
56}
57
58struct CoroutineEntry<
59 T: Send + 'static,
60 F: FnOnce() -> T + Send + 'static,
61 E: FnOnce(Result<T, Box<Any + Send>>) + Send + 'static
62> {
63 entry: Option<F>,
64 on_exit: Option<E>
65}
66
67extern "C" fn _launch<
68 T: Send + 'static,
69 F: FnOnce() -> T + Send + 'static,
70 E: FnOnce(Result<T, Box<Any + Send>>) + Send + 'static
71>(co: *const CoroutineImpl) {
72 let mut target = unsafe { Box::from_raw(
73 extract_co_user_data(co) as *const CoroutineEntry<T, F, E> as *mut CoroutineEntry<T, F, E>
74 ) };
75 let entry = target.entry.take().unwrap();
76 let ret = catch_unwind(AssertUnwindSafe(move || (entry)()));
77 (target.on_exit.take().unwrap())(ret);
78}
79
80fn spawn_with_callback<
81 T: Send + 'static,
82 F: FnOnce() -> T + Send + 'static,
83 E: FnOnce(Result<T, Box<Any + Send>>) + Send + 'static
84>(entry: F, cb: E) {
85 let real_version = unsafe { ubh_get_version() };
86 if real_version != REQUIRED_BACKEND_VERSION {
87 panic!("Backend version mismatch (expected: {}, got: {})", REQUIRED_BACKEND_VERSION, real_version);
88 }
89
90 let co = Box::new(CoroutineEntry {
91 entry: Some(entry),
92 on_exit: Some(cb)
93 });
94 unsafe {
95 launch_co(
96 _launch::<T, F, E>,
97 Box::into_raw(co) as *const AnyUserData
98 );
99 }
100}
101
102pub struct JoinHandle<T: Send + 'static> {
104 state: Arc<spin::Mutex<JoinHandleState<T>>>
105}
106
107impl<T: Send + 'static> JoinHandle<T> {
108 fn priv_clone(&self) -> JoinHandle<T> {
109 JoinHandle {
110 state: self.state.clone()
111 }
112 }
113
114 pub fn join(self) -> Result<T, Box<Any + Send>> {
124 Promise::await(move |p| {
125 let mut state = self.state.lock();
126 let result = match ::std::mem::replace(&mut *state, JoinHandleState::Empty) {
127 JoinHandleState::Empty => None,
128 JoinHandleState::Done(v) => Some(v),
129 JoinHandleState::Pending(_) => unreachable!()
130 };
131 if let Some(result) = result {
132 drop(state);
133 p.resolve(result);
134 } else {
135 *state = JoinHandleState::Pending(p);
136 }
137 })
138 }
139}
140
141enum JoinHandleState<T: Send + 'static> {
142 Empty,
143 Done(Result<T, Box<Any + Send>>),
144 Pending(Promise<Result<T, Box<Any + Send>>>)
145}
146
147pub fn fast_spawn<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(entry: F) {
151 spawn_with_callback(entry, |_| {});
152}
153
154pub fn spawn<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(entry: F) -> JoinHandle<T> {
156 let handle = JoinHandle {
157 state: Arc::new(spin::Mutex::new(JoinHandleState::Empty as JoinHandleState<T>))
158 };
159 let handle2 = handle.priv_clone();
160 spawn_with_callback(entry, move |ret| {
161 let mut ret = Some(ret);
162 let mut resolve_target: Option<Promise<Result<T, Box<Any + Send>>>> = None;
163
164 let mut state = handle2.state.lock();
165 let new_state = match ::std::mem::replace(&mut *state, JoinHandleState::Empty) {
166 JoinHandleState::Empty => JoinHandleState::Done(ret.take().unwrap()),
167 JoinHandleState::Pending(p) => {
168 resolve_target = Some(p);
169 JoinHandleState::Empty
170 },
171 JoinHandleState::Done(_) => unreachable!()
172 };
173 *state = new_state;
174 drop(state);
175
176 if let Some(p) = resolve_target {
177 p.resolve(ret.take().unwrap());
178 }
179 });
180 handle
181}
182
183pub fn spawn_inherit<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(entry: F) {
188 if unsafe { current_coroutine() }.is_null() {
189 ::std::thread::spawn(entry);
190 } else {
191 spawn(entry);
192 }
193}
194
195pub fn yield_now() {
198 let co = unsafe { current_coroutine() };
199 if !co.is_null() {
200 unsafe {
201 coroutine_yield(co);
202 }
203 }
204}
205
206pub fn global_event_count() -> usize {
208 unsafe {
209 co_get_global_event_count()
210 }
211}
212
213pub unsafe fn set_work_stealing(enabled: bool) {
219 if enabled {
220 gtp_enable_work_stealing();
221 } else {
222 gtp_disable_work_stealing();
223 }
224}
225
226pub fn migration_count() -> usize {
228 unsafe {
229 gtp_get_migration_count() as usize
230 }
231}
232
#[cfg(test)]
mod tests {
    use std::time::Duration;

    /// The inner task returns immediately while the outer task sleeps. By the
    /// time `join` runs, the outcome is (very likely) already stored, so this
    /// exercises the `Done` branch of `JoinHandle::join`.
    #[test]
    fn test_spawn_join_instant() {
        super::spawn(move || {
            let handle = super::spawn(|| 42);
            ::std::thread::sleep(Duration::from_millis(50));
            let v: i32 = handle.join().unwrap();
            assert_eq!(v, 42);
        })
        .join()
        .unwrap();
    }

    /// The inner task sleeps before returning, so `join` (very likely) runs
    /// first and parks a promise. This exercises the `Pending` branch of
    /// `spawn`'s completion callback.
    #[test]
    fn test_spawn_join_deferred() {
        super::spawn(move || {
            let handle = super::spawn(|| {
                ::std::thread::sleep(Duration::from_millis(50));
                42
            });
            let v: i32 = handle.join().unwrap();
            assert_eq!(v, 42);
        })
        .join()
        .unwrap();
    }
}