coroutines/
lib.rs

1extern crate spin;
2
3pub mod promise;
4
5use std::any::Any;
6use std::sync::Arc;
7use std::panic::{catch_unwind, AssertUnwindSafe};
8pub use promise::Promise;
9
/// Opaque handle to a coroutine owned by the `unblock_hook` backend.
///
/// In this file it only ever appears behind raw pointers. The private field
/// keeps the type non-zero-sized and prevents construction outside this module.
#[repr(C)]
pub(crate) struct CoroutineImpl {
    _dummy: usize,
}
14
/// Opaque stand-in for arbitrary user data passed through the backend.
///
/// Callers cast their own pointers to `*const AnyUserData` and back; the type
/// itself is never instantiated in this file.
#[repr(C)]
pub(crate) struct AnyUserData {
    _dummy: usize,
}
19
20pub(crate) type AsyncEntry = extern "C" fn (co: *const CoroutineImpl, data: *const AnyUserData);
21
/// Backend version this crate was written against. `spawn_with_callback`
/// panics if `ubh_get_version()` reports any other value.
const REQUIRED_BACKEND_VERSION: i32 = 20;
23
24#[link(name = "unblock_hook", kind = "dylib")]
25extern "C" {
26    pub(crate) fn current_coroutine() -> *const CoroutineImpl;
27    fn launch_co(
28        entry: extern "C" fn (*const CoroutineImpl),
29        user_data: *const AnyUserData
30    );
31    fn extract_co_user_data(
32        co: *const CoroutineImpl
33    ) -> *const AnyUserData;
34    fn co_get_global_event_count() -> usize;
35
36    fn coroutine_yield(
37        co: *const CoroutineImpl
38    );
39
40    fn gtp_enable_work_stealing();
41    fn gtp_disable_work_stealing();
42    fn gtp_get_migration_count() -> i32;
43
44    fn ubh_get_version() -> i32;
45
46    pub(crate) fn coroutine_async_enter(
47        co: *const CoroutineImpl,
48        entry: AsyncEntry,
49        data: *const AnyUserData
50    ) -> *const AnyUserData;
51    
52    pub(crate) fn coroutine_async_exit(
53        co: *const CoroutineImpl,
54        data: *const AnyUserData
55    );
56}
57
58struct CoroutineEntry<
59    T: Send + 'static,
60    F: FnOnce() -> T + Send + 'static,
61    E: FnOnce(Result<T, Box<Any + Send>>) + Send + 'static
62> {
63    entry: Option<F>,
64    on_exit: Option<E>
65}
66
67extern "C" fn _launch<
68    T: Send + 'static,
69    F: FnOnce() -> T + Send + 'static,
70    E: FnOnce(Result<T, Box<Any + Send>>) + Send + 'static
71>(co: *const CoroutineImpl) {
72    let mut target = unsafe { Box::from_raw(
73        extract_co_user_data(co) as *const CoroutineEntry<T, F, E> as *mut CoroutineEntry<T, F, E>
74    ) };
75    let entry = target.entry.take().unwrap();
76    let ret = catch_unwind(AssertUnwindSafe(move || (entry)()));
77    (target.on_exit.take().unwrap())(ret);
78}
79
80fn spawn_with_callback<
81    T: Send + 'static,
82    F: FnOnce() -> T + Send + 'static,
83    E: FnOnce(Result<T, Box<Any + Send>>) + Send + 'static
84>(entry: F, cb: E) {
85    let real_version = unsafe { ubh_get_version() };
86    if real_version != REQUIRED_BACKEND_VERSION {
87        panic!("Backend version mismatch (expected: {}, got: {})", REQUIRED_BACKEND_VERSION, real_version);
88    }
89
90    let co = Box::new(CoroutineEntry {
91        entry: Some(entry),
92        on_exit: Some(cb)
93    });
94    unsafe {
95        launch_co(
96            _launch::<T, F, E>,
97            Box::into_raw(co) as *const AnyUserData
98        );
99    }
100}
101
102/// A handle used to wait on a coroutine's termination.
103pub struct JoinHandle<T: Send + 'static> {
104    state: Arc<spin::Mutex<JoinHandleState<T>>>
105}
106
107impl<T: Send + 'static> JoinHandle<T> {
108    fn priv_clone(&self) -> JoinHandle<T> {
109        JoinHandle {
110            state: self.state.clone()
111        }
112    }
113
114    /// Waits for the associated coroutine to finish.
115    ///
116    /// If the associated coroutine has already terminated,
117    /// `join` returns instantly with the result.
118    /// Otherwise, `join` waits until the coroutine terminates.
119    ///
120    /// If the child coroutine panics, `Err` is returned with the
121    /// boxed value passed to `panic`. Otherwise, `Ok` is returned
122    /// with the return value of the closure executed in the coroutine.
123    pub fn join(self) -> Result<T, Box<Any + Send>> {
124        Promise::await(move |p| {
125            let mut state = self.state.lock();
126            let result = match ::std::mem::replace(&mut *state, JoinHandleState::Empty) {
127                JoinHandleState::Empty => None,
128                JoinHandleState::Done(v) => Some(v),
129                JoinHandleState::Pending(_) => unreachable!()
130            };
131            if let Some(result) = result {
132                drop(state);
133                p.resolve(result);
134            } else {
135                *state = JoinHandleState::Pending(p);
136            }
137        })
138    }
139}
140
141enum JoinHandleState<T: Send + 'static> {
142    Empty,
143    Done(Result<T, Box<Any + Send>>),
144    Pending(Promise<Result<T, Box<Any + Send>>>)
145}
146
147/// Spawns a coroutine without building a `JoinHandle`.
148///
149/// This may be faster than `spawn` in some cases.
150pub fn fast_spawn<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(entry: F) {
151    spawn_with_callback(entry, |_| {});
152}
153
154/// Spawns a coroutine and returns its `JoinHandle`.
155pub fn spawn<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(entry: F) -> JoinHandle<T> {
156    let handle = JoinHandle {
157        state: Arc::new(spin::Mutex::new(JoinHandleState::Empty as JoinHandleState<T>))
158    };
159    let handle2 = handle.priv_clone();
160    spawn_with_callback(entry, move |ret| {
161        let mut ret = Some(ret);
162        let mut resolve_target: Option<Promise<Result<T, Box<Any + Send>>>> = None;
163
164        let mut state = handle2.state.lock();
165        let new_state = match ::std::mem::replace(&mut *state, JoinHandleState::Empty) {
166            JoinHandleState::Empty => JoinHandleState::Done(ret.take().unwrap()),
167            JoinHandleState::Pending(p) => {
168                resolve_target = Some(p);
169                JoinHandleState::Empty
170            },
171            JoinHandleState::Done(_) => unreachable!()
172        };
173        *state = new_state;
174        drop(state);
175
176        if let Some(p) = resolve_target {
177            p.resolve(ret.take().unwrap());
178        }
179    });
180    handle
181}
182
183/// Deprecated.
184///
185/// Spawns another coroutine if called inside a coroutine,
186/// or an OS thread otherwise.
187pub fn spawn_inherit<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(entry: F) {
188    if unsafe { current_coroutine() }.is_null() {
189        ::std::thread::spawn(entry);
190    } else {
191        spawn(entry);
192    }
193}
194
195/// Yields out of the current coroutine and allows the scheduler
196/// to execute other coroutines.
197pub fn yield_now() {
198    let co = unsafe { current_coroutine() };
199    if !co.is_null() {
200        unsafe {
201            coroutine_yield(co);
202        }
203    }
204}
205
206/// Returns the global event count.
207pub fn global_event_count() -> usize {
208    unsafe {
209        co_get_global_event_count()
210    }
211}
212
213/// Enable / disable coroutine migration ("work stealing")
214/// globally. (disabled by default)
215///
216/// This should be used with care because migrating values of
217/// non-Send types might break Rust's safety guarantee.
218pub unsafe fn set_work_stealing(enabled: bool) {
219    if enabled {
220        gtp_enable_work_stealing();
221    } else {
222        gtp_disable_work_stealing();
223    }
224}
225
226/// Returns the current global migration count.
227pub fn migration_count() -> usize {
228    unsafe {
229        gtp_get_migration_count() as usize
230    }
231}
232
#[cfg(test)]
mod tests {
    use std::time::Duration;

    /// The child finishes (during the parent's sleep) before `join` is called,
    /// so `join` should find the stored result.
    #[test]
    fn test_spawn_join_instant() {
        super::spawn(move || {
            let handle = super::spawn(|| {
                42
            });
            ::std::thread::sleep(Duration::from_millis(50));
            let v: i32 = handle.join().unwrap();
            // `assert_eq!` reports the actual value on failure.
            assert_eq!(v, 42);
        }).join().unwrap();
    }

    /// `join` is called while the child is still sleeping, so it must wait for
    /// the exit callback to resolve it.
    #[test]
    fn test_spawn_join_deferred() {
        super::spawn(move || {
            let handle = super::spawn(|| {
                ::std::thread::sleep(Duration::from_millis(50));
                42
            });
            let v: i32 = handle.join().unwrap();
            assert_eq!(v, 42);
        }).join().unwrap();
    }
}