1use std::any::Any;
2use std::cell::{Cell, UnsafeCell};
3use std::rc::Rc;
4use std::{mem, panic, ptr};
5
6use ignore_result::Ignore;
7use static_assertions::assert_not_impl_any;
8
9use super::Coroutine;
10use crate::error::{JoinError, PanicError};
11use crate::select::{Identifier, Permit, PermitReader, Selectable, Selector, TrySelectError};
12use crate::task::{self, Yielding};
13
14enum SuspensionState<T: 'static> {
15 Empty,
16 Value(T),
17 Panicked(PanicError),
18 Joining(ptr::NonNull<Coroutine>),
19 Selector(Selector),
20 Joined,
21}
22
23struct SuspensionJoint<T: 'static> {
24 state: UnsafeCell<SuspensionState<T>>,
25 wakers: Cell<usize>,
26}
27
28impl<T> Yielding for SuspensionJoint<T> {
29 fn interrupt(&self, reason: &'static str) -> bool {
30 self.cancel(PanicError::Static(reason));
31 true
32 }
33}
34
35impl<T> SuspensionJoint<T> {
36 fn new() -> Rc<SuspensionJoint<T>> {
37 Rc::new(SuspensionJoint { state: UnsafeCell::new(SuspensionState::Empty), wakers: Cell::new(1) })
38 }
39
40 fn is_ready(&self) -> bool {
41 let state = unsafe { &*self.state.get() };
42 matches!(state, SuspensionState::Value(_) | SuspensionState::Panicked(_))
43 }
44
45 fn wake_coroutine(co: ptr::NonNull<Coroutine>) {
46 let task = unsafe { task::current().as_mut() };
47 task.resume(co);
48 }
49
50 fn add_waker(&self) {
51 let wakers = self.wakers.get() + 1;
52 self.wakers.set(wakers);
53 }
54
55 fn remove_waker(&self) {
56 let wakers = self.wakers.get() - 1;
57 self.wakers.set(wakers);
58 if wakers == 0 {
59 self.fault(PanicError::Static("suspend: no resumption"));
60 }
61 }
62
63 fn cancel(&self, err: PanicError) -> Option<ptr::NonNull<Coroutine>> {
64 let state = unsafe { &mut *self.state.get() };
65 if matches!(state, SuspensionState::Value(_) | SuspensionState::Panicked(_) | SuspensionState::Joined) {
66 return None;
67 }
68 let state = unsafe { ptr::replace(state, SuspensionState::Panicked(err)) };
69 if let SuspensionState::Joining(co) = state {
70 return Some(co);
71 } else if let SuspensionState::Selector(selector) = state {
72 selector.apply(Permit::default());
73 }
74 None
75 }
76
77 fn fault(&self, err: PanicError) {
78 if let Some(co) = self.cancel(err) {
79 Self::wake_coroutine(co);
80 }
81 }
82
83 pub fn wake(&self, value: T) -> Result<(), T> {
84 let state = unsafe { &mut *self.state.get() };
85 if matches!(state, SuspensionState::Value(_) | SuspensionState::Panicked(_) | SuspensionState::Joined) {
86 return Err(value);
87 }
88 let state = unsafe { ptr::replace(state, SuspensionState::Value(value)) };
89 if let SuspensionState::Joining(co) = state {
90 Self::wake_coroutine(co);
91 } else if let SuspensionState::Selector(selector) = state {
92 selector.apply(Permit::default());
93 }
94 Ok(())
95 }
96
97 fn set_result(&self, result: Result<T, Box<dyn Any + Send + 'static>>) {
98 match result {
99 Ok(value) => self.wake(value).ignore(),
100 Err(err) => self.fault(PanicError::Unwind(err)),
101 }
102 }
103
104 fn watch_permit(&self, selector: Selector) -> bool {
105 let state = unsafe { &mut *self.state.get() };
106 match state {
107 SuspensionState::Value(_) | SuspensionState::Panicked(_) | SuspensionState::Joined => {
108 selector.apply(Permit::default());
109 return true;
110 },
111 SuspensionState::Joining(_) => unreachable!("suspension: joining state"),
112 SuspensionState::Selector(_) => unreachable!("suspension: selecting"),
113 SuspensionState::Empty => unsafe { ptr::write(state, SuspensionState::Selector(selector)) },
114 }
115 false
116 }
117
118 fn unwatch_permit(&self, identifer: &Identifier) {
119 let state = unsafe { &mut *self.state.get() };
120 if let SuspensionState::Selector(selector) = state {
121 assert!(selector.identify(identifer), "suspension: selecting by other");
122 *state = SuspensionState::Empty;
123 }
124 }
125
126 fn consume_permit(&self) -> Result<T, PanicError> {
127 self.take()
128 }
129
130 fn take(&self) -> Result<T, PanicError> {
131 let state = mem::replace(unsafe { &mut *self.state.get() }, SuspensionState::Joined);
132 match state {
133 SuspensionState::Value(value) => Ok(value),
134 SuspensionState::Panicked(err) => Err(err),
135 SuspensionState::Empty => unreachable!("suspension: empty state"),
136 SuspensionState::Joining(_) => unreachable!("suspension: joining state"),
137 SuspensionState::Joined => unreachable!("suspension: joined state"),
138 SuspensionState::Selector(_) => unreachable!("suspension: selecting"),
139 }
140 }
141
142 fn join(&self) -> Result<T, PanicError> {
143 let co = super::current();
144 let state = mem::replace(unsafe { &mut *self.state.get() }, SuspensionState::Joining(co));
145 match state {
146 SuspensionState::Empty => {
147 let task = unsafe { task::current().as_mut() };
148 task.suspend(co, self);
149 self.take()
150 },
151 SuspensionState::Value(value) => Ok(value),
152 SuspensionState::Panicked(err) => Err(err),
153 SuspensionState::Joining(_) => unreachable!("suspension: join joining state"),
154 SuspensionState::Joined => unreachable!("suspension: join joined state"),
155 SuspensionState::Selector(_) => unreachable!("suspension: selecting"),
156 }
157 }
158}
159
160pub struct Suspension<T: 'static>(Rc<SuspensionJoint<T>>);
162
163pub struct Resumption<T: 'static> {
165 joint: Rc<SuspensionJoint<T>>,
166}
167
168assert_not_impl_any!(Suspension<()>: Send);
169assert_not_impl_any!(Resumption<()>: Send);
170
171impl<T> Suspension<T> {
172 unsafe fn into_joint(self) -> Rc<SuspensionJoint<T>> {
173 let joint = ptr::read(&self.0);
174 mem::forget(self);
175 joint
176 }
177
178 pub fn is_ready(&self) -> bool {
180 self.0.is_ready()
181 }
182
183 pub fn suspend(self) -> T {
195 let joint = unsafe { self.into_joint() };
196 match joint.join() {
197 Ok(value) => value,
198 Err(PanicError::Unwind(err)) => panic::resume_unwind(err),
199 Err(PanicError::Static(s)) => panic::panic_any(s),
200 }
201 }
202}
203
204impl<T> Drop for Suspension<T> {
205 fn drop(&mut self) {
206 self.0.cancel(PanicError::Static("suspension dropped"));
207 }
208}
209
210impl<T> Resumption<T> {
211 fn new(joint: Rc<SuspensionJoint<T>>) -> Self {
212 Resumption { joint }
213 }
214
215 unsafe fn into_joint(self) -> Rc<SuspensionJoint<T>> {
217 let joint = ptr::read(&self.joint);
218 mem::forget(self);
219 joint
220 }
221
222 pub fn resume(self, value: T) -> bool {
224 let joint = unsafe { self.into_joint() };
225 joint.wake(value).is_ok()
226 }
227
228 pub fn send(self, value: T) -> Result<(), T> {
230 let joint = unsafe { self.into_joint() };
231 joint.wake(value)
232 }
233
234 pub(super) fn set_result(self, result: Result<T, Box<dyn Any + Send + 'static>>) {
235 let joint = unsafe { self.into_joint() };
236 joint.set_result(result);
237 }
238}
239
240impl<T> Clone for Resumption<T> {
241 fn clone(&self) -> Self {
242 self.joint.add_waker();
243 Resumption { joint: self.joint.clone() }
244 }
245}
246
247impl<T> Drop for Resumption<T> {
248 fn drop(&mut self) {
249 self.joint.remove_waker();
250 }
251}
252
253pub fn suspension<T>() -> (Suspension<T>, Resumption<T>) {
255 let joint = SuspensionJoint::new();
256 let suspension = Suspension(joint.clone());
257 (suspension, Resumption::new(joint))
258}
259
260pub struct JoinHandle<T: 'static> {
262 suspension: Option<Suspension<T>>,
263}
264
265assert_not_impl_any!(JoinHandle<()>: Send);
266
267impl<T> JoinHandle<T> {
268 pub(super) fn new(suspension: Suspension<T>) -> Self {
269 JoinHandle { suspension: Some(suspension) }
270 }
271
272 pub fn join(mut self) -> Result<T, JoinError> {
278 if let Some(suspension) = self.suspension.take() {
279 let joint = unsafe { suspension.into_joint() };
280 joint.join().map_err(JoinError::new)
281 } else {
282 panic!("already joined by select")
283 }
284 }
285}
286
287impl<T: 'static> Selectable for JoinHandle<T> {
288 fn parallel(&self) -> bool {
289 false
290 }
291
292 fn select_permit(&self) -> Result<Permit, TrySelectError> {
293 if let Some(suspension) = self.suspension.as_ref() {
294 if suspension.is_ready() {
295 Ok(Permit::default())
296 } else {
297 Err(TrySelectError::WouldBlock)
298 }
299 } else {
300 Err(TrySelectError::Completed)
301 }
302 }
303
304 fn watch_permit(&self, selector: Selector) -> bool {
305 if let Some(suspension) = self.suspension.as_ref() {
306 suspension.0.watch_permit(selector)
307 } else {
308 false
309 }
310 }
311
312 fn unwatch_permit(&self, identifier: &Identifier) {
313 if let Some(suspension) = self.suspension.as_ref() {
314 suspension.0.unwatch_permit(identifier);
315 }
316 }
317}
318
319impl<T: 'static> PermitReader for JoinHandle<T> {
320 type Result = Result<T, JoinError>;
321
322 fn consume_permit(&mut self, _permit: Permit) -> Result<T, JoinError> {
323 if let Some(suspension) = self.suspension.take() {
324 let joint = unsafe { suspension.into_joint() };
325 joint.consume_permit().map_err(JoinError::new)
326 } else {
327 panic!("JoinHandle: already consumed")
328 }
329 }
330}
331
#[cfg(test)]
mod tests {
    use ignore_result::Ignore;

    use crate::{coroutine, select};

    /// Two resumers race; exactly one delivery is accepted and the other gets its value back.
    #[crate::test(crate = "crate")]
    fn resumption() {
        let (suspension, resumption) = coroutine::suspension();
        drop(resumption.clone());
        assert!(!suspension.0.is_ready());
        let co1 = coroutine::spawn({
            let resumption = resumption.clone();
            move || resumption.send(5)
        });
        let co2 = coroutine::spawn(move || resumption.send(6));
        let value = suspension.suspend();
        let mut result1 = co1.join().unwrap();
        let mut result2 = co2.join().unwrap();
        if result1.is_err() {
            std::mem::swap(&mut result1, &mut result2);
        }
        assert_eq!(result1, Ok(()));
        assert!(result2.is_err());
        assert_eq!(value, 11 - result2.unwrap_err());
    }

    /// Dropping the suspension stores a failure, which makes the joint ready.
    #[crate::test(crate = "crate")]
    fn suspension_dropped() {
        let (suspension, resumption) = coroutine::suspension::<()>();
        drop(suspension);
        assert!(resumption.joint.is_ready());
    }

    #[crate::test(crate = "crate")]
    #[should_panic(expected = "deadlock suspending coroutines")]
    fn suspension_deadlock() {
        let (suspension, resumption) = coroutine::suspension::<()>();
        suspension.suspend();
        drop(resumption);
    }

    #[crate::test(crate = "crate")]
    fn join_handle_join() {
        let join_handle = coroutine::spawn(|| 5);
        assert_eq!(join_handle.join().unwrap(), 5);
    }

    #[crate::test(crate = "crate")]
    fn join_handle_join_panic() {
        const REASON: &str = "oooooops";
        let co = coroutine::spawn(|| panic!("{}", REASON));
        let err = co.join().unwrap_err();
        assert!(err.to_string().contains(REASON))
    }

    #[crate::test(crate = "crate")]
    fn join_handle_select() {
        let mut join_handle = coroutine::spawn(|| 5);
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
    }

    #[crate::test(crate = "crate")]
    fn join_handle_select_complete() {
        let mut join_handle = coroutine::spawn(|| 5);
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
            complete => {},
        }
    }

    #[crate::test(crate = "crate")]
    #[should_panic(expected = "already joined by select")]
    fn join_handle_join_consumed() {
        let mut join_handle = coroutine::spawn(|| 5);
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
        join_handle.join().ignore();
    }

    #[crate::test(crate = "crate")]
    #[should_panic(expected = "all selectables are disabled or completed")]
    fn join_handle_select_consumed() {
        let mut join_handle = coroutine::spawn(|| 5);
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
    }
}