// stuck/coroutine/suspension.rs

1use std::any::Any;
2use std::cell::{Cell, UnsafeCell};
3use std::rc::Rc;
4use std::{mem, panic, ptr};
5
6use ignore_result::Ignore;
7use static_assertions::assert_not_impl_any;
8
9use super::Coroutine;
10use crate::error::{JoinError, PanicError};
11use crate::select::{Identifier, Permit, PermitReader, Selectable, Selector, TrySelectError};
12use crate::task::{self, Yielding};
13
14enum SuspensionState<T: 'static> {
15    Empty,
16    Value(T),
17    Panicked(PanicError),
18    Joining(ptr::NonNull<Coroutine>),
19    Selector(Selector),
20    Joined,
21}
22
23struct SuspensionJoint<T: 'static> {
24    state: UnsafeCell<SuspensionState<T>>,
25    wakers: Cell<usize>,
26}
27
28impl<T> Yielding for SuspensionJoint<T> {
29    fn interrupt(&self, reason: &'static str) -> bool {
30        self.cancel(PanicError::Static(reason));
31        true
32    }
33}
34
35impl<T> SuspensionJoint<T> {
36    fn new() -> Rc<SuspensionJoint<T>> {
37        Rc::new(SuspensionJoint { state: UnsafeCell::new(SuspensionState::Empty), wakers: Cell::new(1) })
38    }
39
40    fn is_ready(&self) -> bool {
41        let state = unsafe { &*self.state.get() };
42        matches!(state, SuspensionState::Value(_) | SuspensionState::Panicked(_))
43    }
44
45    fn wake_coroutine(co: ptr::NonNull<Coroutine>) {
46        let task = unsafe { task::current().as_mut() };
47        task.resume(co);
48    }
49
50    fn add_waker(&self) {
51        let wakers = self.wakers.get() + 1;
52        self.wakers.set(wakers);
53    }
54
55    fn remove_waker(&self) {
56        let wakers = self.wakers.get() - 1;
57        self.wakers.set(wakers);
58        if wakers == 0 {
59            self.fault(PanicError::Static("suspend: no resumption"));
60        }
61    }
62
63    fn cancel(&self, err: PanicError) -> Option<ptr::NonNull<Coroutine>> {
64        let state = unsafe { &mut *self.state.get() };
65        if matches!(state, SuspensionState::Value(_) | SuspensionState::Panicked(_) | SuspensionState::Joined) {
66            return None;
67        }
68        let state = unsafe { ptr::replace(state, SuspensionState::Panicked(err)) };
69        if let SuspensionState::Joining(co) = state {
70            return Some(co);
71        } else if let SuspensionState::Selector(selector) = state {
72            selector.apply(Permit::default());
73        }
74        None
75    }
76
77    fn fault(&self, err: PanicError) {
78        if let Some(co) = self.cancel(err) {
79            Self::wake_coroutine(co);
80        }
81    }
82
83    pub fn wake(&self, value: T) -> Result<(), T> {
84        let state = unsafe { &mut *self.state.get() };
85        if matches!(state, SuspensionState::Value(_) | SuspensionState::Panicked(_) | SuspensionState::Joined) {
86            return Err(value);
87        }
88        let state = unsafe { ptr::replace(state, SuspensionState::Value(value)) };
89        if let SuspensionState::Joining(co) = state {
90            Self::wake_coroutine(co);
91        } else if let SuspensionState::Selector(selector) = state {
92            selector.apply(Permit::default());
93        }
94        Ok(())
95    }
96
97    fn set_result(&self, result: Result<T, Box<dyn Any + Send + 'static>>) {
98        match result {
99            Ok(value) => self.wake(value).ignore(),
100            Err(err) => self.fault(PanicError::Unwind(err)),
101        }
102    }
103
104    fn watch_permit(&self, selector: Selector) -> bool {
105        let state = unsafe { &mut *self.state.get() };
106        match state {
107            SuspensionState::Value(_) | SuspensionState::Panicked(_) | SuspensionState::Joined => {
108                selector.apply(Permit::default());
109                return true;
110            },
111            SuspensionState::Joining(_) => unreachable!("suspension: joining state"),
112            SuspensionState::Selector(_) => unreachable!("suspension: selecting"),
113            SuspensionState::Empty => unsafe { ptr::write(state, SuspensionState::Selector(selector)) },
114        }
115        false
116    }
117
118    fn unwatch_permit(&self, identifer: &Identifier) {
119        let state = unsafe { &mut *self.state.get() };
120        if let SuspensionState::Selector(selector) = state {
121            assert!(selector.identify(identifer), "suspension: selecting by other");
122            *state = SuspensionState::Empty;
123        }
124    }
125
126    fn consume_permit(&self) -> Result<T, PanicError> {
127        self.take()
128    }
129
130    fn take(&self) -> Result<T, PanicError> {
131        let state = mem::replace(unsafe { &mut *self.state.get() }, SuspensionState::Joined);
132        match state {
133            SuspensionState::Value(value) => Ok(value),
134            SuspensionState::Panicked(err) => Err(err),
135            SuspensionState::Empty => unreachable!("suspension: empty state"),
136            SuspensionState::Joining(_) => unreachable!("suspension: joining state"),
137            SuspensionState::Joined => unreachable!("suspension: joined state"),
138            SuspensionState::Selector(_) => unreachable!("suspension: selecting"),
139        }
140    }
141
142    fn join(&self) -> Result<T, PanicError> {
143        let co = super::current();
144        let state = mem::replace(unsafe { &mut *self.state.get() }, SuspensionState::Joining(co));
145        match state {
146            SuspensionState::Empty => {
147                let task = unsafe { task::current().as_mut() };
148                task.suspend(co, self);
149                self.take()
150            },
151            SuspensionState::Value(value) => Ok(value),
152            SuspensionState::Panicked(err) => Err(err),
153            SuspensionState::Joining(_) => unreachable!("suspension: join joining state"),
154            SuspensionState::Joined => unreachable!("suspension: join joined state"),
155            SuspensionState::Selector(_) => unreachable!("suspension: selecting"),
156        }
157    }
158}
159
160/// Suspension provides method to suspend calling coroutine.
161pub struct Suspension<T: 'static>(Rc<SuspensionJoint<T>>);
162
163/// Resumption provides method to resume suspending coroutine.
164pub struct Resumption<T: 'static> {
165    joint: Rc<SuspensionJoint<T>>,
166}
167
// Compile-time checks that neither handle is `Send`: both hold an `Rc` to a joint whose
// state lives in `Cell`/`UnsafeCell`, so they must not cross threads.
assert_not_impl_any!(Suspension<()>: Send);
assert_not_impl_any!(Resumption<()>: Send);
170
171impl<T> Suspension<T> {
172    unsafe fn into_joint(self) -> Rc<SuspensionJoint<T>> {
173        let joint = ptr::read(&self.0);
174        mem::forget(self);
175        joint
176    }
177
178    /// Checks readiness.
179    pub fn is_ready(&self) -> bool {
180        self.0.is_ready()
181    }
182
183    /// Suspends calling coroutine until [Resumption::resume].
184    ///
185    /// # Panics
186    /// Panic if no resume from [Resumption].
187    ///
188    /// # Guarantee
189    /// Only two situations can happen:
190    /// * This method panics and no value sent
191    /// * This method returns and only one value sent
192    ///
193    /// This means that no value linger after panic.
194    pub fn suspend(self) -> T {
195        let joint = unsafe { self.into_joint() };
196        match joint.join() {
197            Ok(value) => value,
198            Err(PanicError::Unwind(err)) => panic::resume_unwind(err),
199            Err(PanicError::Static(s)) => panic::panic_any(s),
200        }
201    }
202}
203
204impl<T> Drop for Suspension<T> {
205    fn drop(&mut self) {
206        self.0.cancel(PanicError::Static("suspension dropped"));
207    }
208}
209
210impl<T> Resumption<T> {
211    fn new(joint: Rc<SuspensionJoint<T>>) -> Self {
212        Resumption { joint }
213    }
214
215    // SAFETY: Forget self and cancel drop to wake peer with no suspension interleave.
216    unsafe fn into_joint(self) -> Rc<SuspensionJoint<T>> {
217        let joint = ptr::read(&self.joint);
218        mem::forget(self);
219        joint
220    }
221
222    /// Resumes suspending coroutine.
223    pub fn resume(self, value: T) -> bool {
224        let joint = unsafe { self.into_joint() };
225        joint.wake(value).is_ok()
226    }
227
228    /// Sends and wakes peer if not waked.
229    pub fn send(self, value: T) -> Result<(), T> {
230        let joint = unsafe { self.into_joint() };
231        joint.wake(value)
232    }
233
234    pub(super) fn set_result(self, result: Result<T, Box<dyn Any + Send + 'static>>) {
235        let joint = unsafe { self.into_joint() };
236        joint.set_result(result);
237    }
238}
239
240impl<T> Clone for Resumption<T> {
241    fn clone(&self) -> Self {
242        self.joint.add_waker();
243        Resumption { joint: self.joint.clone() }
244    }
245}
246
247impl<T> Drop for Resumption<T> {
248    fn drop(&mut self) {
249        self.joint.remove_waker();
250    }
251}
252
253/// Constructs cooperative facilities to suspend and resume coroutine in current task.
254pub fn suspension<T>() -> (Suspension<T>, Resumption<T>) {
255    let joint = SuspensionJoint::new();
256    let suspension = Suspension(joint.clone());
257    (suspension, Resumption::new(joint))
258}
259
260/// JoinHandle provides method to retrieve result of associated cooperative task.
261pub struct JoinHandle<T: 'static> {
262    suspension: Option<Suspension<T>>,
263}
264
// Compile-time check: `JoinHandle` wraps a non-`Send` `Suspension`, so it must not be `Send`.
assert_not_impl_any!(JoinHandle<()>: Send);
266
267impl<T> JoinHandle<T> {
268    pub(super) fn new(suspension: Suspension<T>) -> Self {
269        JoinHandle { suspension: Some(suspension) }
270    }
271
272    /// Waits for associated coroutine to finish and returns its result.
273    ///
274    /// # Panics
275    /// * Panic if already joined by `select!`.
276    /// * Panic if main coroutine finished.
277    pub fn join(mut self) -> Result<T, JoinError> {
278        if let Some(suspension) = self.suspension.take() {
279            let joint = unsafe { suspension.into_joint() };
280            joint.join().map_err(JoinError::new)
281        } else {
282            panic!("already joined by select")
283        }
284    }
285}
286
287impl<T: 'static> Selectable for JoinHandle<T> {
288    fn parallel(&self) -> bool {
289        false
290    }
291
292    fn select_permit(&self) -> Result<Permit, TrySelectError> {
293        if let Some(suspension) = self.suspension.as_ref() {
294            if suspension.is_ready() {
295                Ok(Permit::default())
296            } else {
297                Err(TrySelectError::WouldBlock)
298            }
299        } else {
300            Err(TrySelectError::Completed)
301        }
302    }
303
304    fn watch_permit(&self, selector: Selector) -> bool {
305        if let Some(suspension) = self.suspension.as_ref() {
306            suspension.0.watch_permit(selector)
307        } else {
308            false
309        }
310    }
311
312    fn unwatch_permit(&self, identifier: &Identifier) {
313        if let Some(suspension) = self.suspension.as_ref() {
314            suspension.0.unwatch_permit(identifier);
315        }
316    }
317}
318
319impl<T: 'static> PermitReader for JoinHandle<T> {
320    type Result = Result<T, JoinError>;
321
322    fn consume_permit(&mut self, _permit: Permit) -> Result<T, JoinError> {
323        if let Some(suspension) = self.suspension.take() {
324            let joint = unsafe { suspension.into_joint() };
325            joint.consume_permit().map_err(JoinError::new)
326        } else {
327            panic!("JoinHandle: already consumed")
328        }
329    }
330}
331
/// Behavioral tests for suspension/resumption and `JoinHandle` join/select paths.
#[cfg(test)]
mod tests {
    use ignore_result::Ignore;

    use crate::{coroutine, select};

    // Two senders race on clones of one `Resumption`: exactly one `send` succeeds and its
    // value is what `suspend` returns; the other gets its value back in `Err`.
    #[crate::test(crate = "crate")]
    fn resumption() {
        let (suspension, resumption) = coroutine::suspension();
        // Dropping an extra clone only decrements the waker count (2 -> 1); no fault.
        drop(resumption.clone());
        assert_eq!(suspension.0.is_ready(), false);
        let co1 = coroutine::spawn({
            let resumption = resumption.clone();
            move || resumption.send(5)
        });
        let co2 = coroutine::spawn(move || resumption.send(6));
        let value = suspension.suspend();
        let mut result1 = co1.join().unwrap();
        let mut result2 = co2.join().unwrap();
        // Normalize so that `result1` is the winning send.
        if result1.is_err() {
            std::mem::swap(&mut result1, &mut result2);
        }
        assert_eq!(result1, Ok(()));
        assert_eq!(result2.is_err(), true);
        // 5 + 6 == 11: the received value is whichever one the loser did not get back.
        assert_eq!(value, 11 - result2.unwrap_err());
    }

    // Dropping the `Suspension` cancels the joint, which then reports ready (panicked).
    #[crate::test(crate = "crate")]
    fn suspension_dropped() {
        let (suspension, resumption) = coroutine::suspension::<()>();
        drop(suspension);
        assert_eq!(resumption.joint.is_ready(), true);
    }

    // The only `Resumption` is alive but never used; the expected panic message comes
    // from the task runtime (outside this file), which presumably detects the deadlock.
    #[crate::test(crate = "crate")]
    #[should_panic(expected = "deadlock suspending coroutines")]
    fn suspension_deadlock() {
        let (suspension, resumption) = coroutine::suspension::<()>();
        suspension.suspend();
        drop(resumption);
    }

    // `join` returns the spawned coroutine's return value.
    #[crate::test(crate = "crate")]
    fn join_handle_join() {
        let join_handle = coroutine::spawn(|| 5);
        assert_eq!(join_handle.join().unwrap(), 5);
    }

    // A panic inside the spawned coroutine surfaces as a `JoinError` carrying the message.
    #[crate::test(crate = "crate")]
    fn join_handle_join_panic() {
        const REASON: &'static str = "oooooops";
        let co = coroutine::spawn(|| panic!("{}", REASON));
        let err = co.join().unwrap_err();
        assert!(err.to_string().contains(REASON))
    }

    // A `JoinHandle` can be consumed through `select!`.
    #[crate::test(crate = "crate")]
    fn join_handle_select() {
        let mut join_handle = coroutine::spawn(|| 5);
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
    }

    // After being consumed by `select!`, the handle reports completed, so a second
    // `select!` takes the `complete` branch.
    #[crate::test(crate = "crate")]
    fn join_handle_select_complete() {
        let mut join_handle = coroutine::spawn(|| 5);
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
            complete => {},
        }
    }

    // `join` after the result was consumed by `select!` panics.
    #[crate::test(crate = "crate")]
    #[should_panic(expected = "already joined by select")]
    fn join_handle_join_consumed() {
        let mut join_handle = coroutine::spawn(|| 5);
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
        join_handle.join().ignore();
    }

    // A second `select!` on a consumed handle without a `complete` branch panics.
    #[crate::test(crate = "crate")]
    #[should_panic(expected = "all selectables are disabled or completed")]
    fn join_handle_select_consumed() {
        let mut join_handle = coroutine::spawn(|| 5);
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
    }
}