1use crate::raw::{Fut, Header, PollStatus, RawTask, RawTaskHeader, RawTaskVTable};
2use crate::thin_arc::ThinArc;
3use crate::{JoinError, JoinHandle, TaskId};
4
5use std::cell::UnsafeCell;
6use std::fmt;
7use std::panic::{AssertUnwindSafe, catch_unwind};
8use std::task::{Poll, Waker};
9
10pub struct BlockingTask {
11 raw: RawTask,
12}
13
14unsafe impl Send for BlockingTask {}
15unsafe impl Sync for BlockingTask {}
16
17impl BlockingTask {
18 pub fn new<F, T>(f: F) -> (BlockingTask, JoinHandle<T>)
19 where
20 F: FnOnce() -> T + Send + 'static,
21 T: Send + 'static,
22 {
23 let (raw, join) = unsafe {
24 ThinArc::new(Box::new(RawTaskHeader {
25 header: Header::new(),
26 data: BlockingRawTask {
27 func: UnsafeCell::new(Fut::Future(f)),
28 },
29 }))
30 };
31 (BlockingTask { raw }, JoinHandle::new(join))
32 }
33
34 pub fn run(self) {
35 unsafe { self.raw.poll(Waker::noop()) };
36 }
37
38 #[inline]
39 pub fn id(&self) -> TaskId {
40 TaskId::new(&self.raw)
41 }
42}
43
44pub struct BlockingRawTask<F, T> {
45 func: UnsafeCell<Fut<F, T>>,
46}
47
48impl<F, T> RawTaskVTable for RawTaskHeader<BlockingRawTask<F, T>>
49where
50 F: FnOnce() -> T + Send + 'static,
51 T: Send + 'static,
52{
53 fn waker(&self, _: RawTask) -> Waker {
54 unreachable!()
55 }
56
57 unsafe fn metadata(&self) -> *mut () {
58 std::ptr::null_mut()
59 }
60
61 unsafe fn poll(&self, _: &Waker) -> PollStatus {
62 let maybe_panicked = catch_unwind(AssertUnwindSafe(|| {
63 let output = match (*self.data.func.get()).take() {
64 Fut::Future(func) => func(), _ => unreachable!(),
66 };
67 (*self.data.func.get()).set_output(Ok(output));
69 }));
70
71 if let Err(err) = maybe_panicked {
72 (*self.data.func.get()).set_output(Err(JoinError::panic(err)));
73 }
74
75 if !self
76 .header
77 .transition_to_complete_and_notify_output_if_intrested()
78 {
79 unsafe { (*self.data.func.get()).drop() };
82 }
83 PollStatus::Complete
84 }
85
86 unsafe fn read_output(&self, dst: *mut (), waker: &Waker) {
87 if self.header.can_read_output_or_notify_when_readable(waker) {
88 *(dst as *mut _) = Poll::Ready((*self.data.func.get()).take_output());
89 }
90 }
91
92 unsafe fn drop_output_from_join_handler(&self) {
93 (*self.data.func.get()).drop();
94 }
95
96 unsafe fn schedule(&self, _: RawTask) {}
97 unsafe fn drop_task(&self) {}
98}
99
100impl fmt::Debug for BlockingTask {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 f.debug_struct("BlockingTask")
103 .field("id", &self.id())
104 .field("state", &self.raw.header().state.load())
105 .finish()
106 }
107}