1use crate::raw::{Fut, Header, PollStatus, RawTask, RawTaskVTable};
2use crate::waker::NOOP_WAKER;
3use crate::{Id, JoinHandle};
4
5use std::panic::AssertUnwindSafe;
6use std::task::{Poll, Waker};
7use std::{cell::UnsafeCell, sync::Arc};
8use std::{fmt, panic};
9
10pub struct BlockingTask {
11 raw: RawTask,
12}
13
14unsafe impl Send for BlockingTask {}
15unsafe impl Sync for BlockingTask {}
16
17impl BlockingTask {
18 pub fn new<F, T>(f: F) -> (BlockingTask, JoinHandle<T>)
19 where
20 F: FnOnce() -> T + Send + 'static,
21 T: Send + 'static,
22 {
23 let raw = Arc::new(BlockingRawTask {
24 header: Header::new(),
25 func: UnsafeCell::new(Fut::Future(f)),
26 });
27 let join = JoinHandle::new(raw.clone());
28 (BlockingTask { raw }, join)
29 }
30
31 pub fn run(self) {
32 unsafe { self.raw.poll(&NOOP_WAKER) };
33 }
34
35 #[inline]
36 pub fn id(&self) -> Id {
37 Id::new(&self.raw)
38 }
39}
40
41pub struct BlockingRawTask<F, T> {
42 header: Header,
43 func: UnsafeCell<Fut<F, T>>,
44}
45
46unsafe impl<F: Send, T> Sync for BlockingRawTask<F, T> {}
47
48impl<F, T> RawTaskVTable for BlockingRawTask<F, T>
49where
50 F: FnOnce() -> T + Send + 'static,
51 T: Send + 'static,
52{
53 fn header(&self) -> &Header {
54 &self.header
55 }
56
57 fn waker(self: Arc<Self>) -> Waker {
58 NOOP_WAKER
59 }
60
61 unsafe fn metadata(&self) -> *mut () {
62 std::ptr::null_mut()
63 }
64
65 unsafe fn poll(&self, _: &Waker) -> PollStatus {
67 let output = match (*self.func.get()).take() {
68 Fut::Future(func) => func(),
69 _ => unreachable!(),
70 };
71 (*self.func.get()).set_output(Ok(output));
72 if !self
73 .header
74 .transition_to_complete_and_notify_output_if_intrested()
75 {
76 unsafe {
77 (*self.func.get()).drop();
78 };
79 }
80 PollStatus::Complete
81 }
82
83 unsafe fn read_output(&self, dst: *mut (), waker: &Waker) {
84 if self.header.can_read_output_or_notify_when_readable(waker) {
85 *(dst as *mut _) = Poll::Ready((*self.func.get()).take_output());
86 }
87 }
88
89 unsafe fn drop_join_handler(&self) {
90 let is_task_complete = self.header.state.unset_waker_and_interested();
91 if is_task_complete {
92 let _ = panic::catch_unwind(AssertUnwindSafe(|| unsafe {
95 (*self.func.get()).drop();
96 }));
97 } else {
98 *self.header.join_waker.get() = None;
99 }
100 }
101
102 unsafe fn abort_task(self: Arc<Self>) {}
103 unsafe fn schedule(self: Arc<Self>) {}
104}
105
106impl fmt::Debug for BlockingTask {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 f.debug_struct("BlockingTask")
109 .field("id", &self.id())
110 .field("state", &self.raw.header().state.load())
111 .finish()
112 }
113}