1use crate::raw::{Fut, Header, PollStatus, RawTask, RawTaskHeader, RawTaskVTable};
2use crate::thin_arc::ThinArc;
3use crate::{JoinError, JoinHandle, TaskId};
4
5use std::cell::UnsafeCell;
6use std::fmt;
7use std::panic::{AssertUnwindSafe, catch_unwind};
8use std::task::{Poll, Waker};
9
10pub struct BlockingTask {
11 raw: RawTask,
12}
13
14unsafe impl Send for BlockingTask {}
15unsafe impl Sync for BlockingTask {}
16
17impl BlockingTask {
18 pub fn new<F, T>(f: F) -> (BlockingTask, JoinHandle<T>)
19 where
20 F: FnOnce() -> T + Send + 'static,
21 T: Send + 'static,
22 {
23 let (raw, join) = ThinArc::new(Box::new(RawTaskHeader {
24 header: Header::new(),
25 data: BlockingRawTask {
26 func: UnsafeCell::new(Fut::Future(f)),
27 },
28 }));
29 (BlockingTask { raw }, JoinHandle::new(join))
30 }
31
32 pub fn run(self) {
33 unsafe { self.raw.poll(Waker::noop()) };
34 }
35
36 #[inline]
37 pub fn id(&self) -> TaskId {
38 TaskId::new(&self.raw)
39 }
40}
41
42pub struct BlockingRawTask<F, T> {
43 func: UnsafeCell<Fut<F, T>>,
44}
45
46impl<F, T> RawTaskVTable for RawTaskHeader<BlockingRawTask<F, T>>
47where
48 F: FnOnce() -> T + Send + 'static,
49 T: Send + 'static,
50{
51 fn waker(&self, _: RawTask) -> Waker {
52 unreachable!()
53 }
54
55 unsafe fn metadata(&self) -> *mut () {
56 std::ptr::null_mut()
57 }
58
59 unsafe fn poll(&self, _: &Waker) -> PollStatus {
60 let maybe_panicked = catch_unwind(AssertUnwindSafe(|| {
61 let output = match (*self.data.func.get()).take() {
62 Fut::Future(func) => func(), _ => unreachable!(),
64 };
65 (*self.data.func.get()).set_output(Ok(output));
67 }));
68
69 if let Err(err) = maybe_panicked {
70 (*self.data.func.get()).set_output(Err(JoinError::panic(err)));
71 }
72
73 if !self
74 .header
75 .transition_to_complete_and_notify_output_if_intrested()
76 {
77 unsafe { (*self.data.func.get()).drop() };
80 }
81 PollStatus::Complete
82 }
83
84 unsafe fn read_output(&self, dst: *mut (), waker: &Waker) {
85 if self.header.can_read_output_or_notify_when_readable(waker) {
86 *(dst as *mut _) = Poll::Ready((*self.data.func.get()).take_output());
87 }
88 }
89
90 unsafe fn drop_output_from_join_handler(&self) {
91 (*self.data.func.get()).drop();
92 }
93
94 unsafe fn schedule(&self, _: RawTask) {}
95 unsafe fn drop_task(&self) {}
96}
97
98impl fmt::Debug for BlockingTask {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 f.debug_struct("BlockingTask")
101 .field("id", &self.id())
102 .field("state", &self.raw.header().state.load())
103 .finish()
104 }
105}