1#![doc = include_str!("../README.md")]
2#![allow(unsafe_op_in_unsafe_fn)]
3
4mod blocking;
5mod error;
6mod id;
7mod join;
8mod raw;
9mod state;
10mod waker;
11
12use crate::raw::*;
13
14pub use blocking::BlockingTask;
15pub use error::JoinError;
16pub use id::Id;
17pub use join::JoinHandle;
18
19use state::*;
20use std::{
21 cell::UnsafeCell,
22 fmt,
23 future::Future,
24 marker::PhantomData,
25 mem::ManuallyDrop,
26 panic::{AssertUnwindSafe, catch_unwind},
27 pin::Pin,
28 sync::Arc,
29 task::{Context, Poll, Wake, Waker},
30};
31
/// Receives tasks that are ready to be run.
///
/// `M` is the type of the per-task metadata carried by the [`Task`] handed to
/// [`Scheduler::schedule`]. Implementors must be `'static`.
pub trait Scheduler<M>: 'static {
    /// Takes ownership of `task`. Within this file it is invoked when a task
    /// is woken or aborted; what happens to the task afterwards is entirely up
    /// to the implementation.
    fn schedule(&self, task: Task<M>);
}
35
36impl<F, M> Scheduler<M> for F
37where
38 F: Fn(Task<M>) + 'static,
39{
40 fn schedule(&self, runnable: Task<M>) {
41 self(runnable)
42 }
43}
44
/// Heap-allocated storage shared (through an `Arc`) by a [`Task`], its
/// [`JoinHandle`] and the wakers created for it.
struct RawTaskInner<F: Future, S: Scheduler<M>, M> {
    /// Task header; all state transitions in this file go through it.
    header: Header,
    /// Starts out as `Fut::Future(future)`; later mutated through
    /// `set_output` / `take_output` / `drop` (defined outside this file).
    future: UnsafeCell<Fut<F, F::Output>>,
    /// User supplied metadata, handed out as a raw pointer by `metadata()`.
    meta: UnsafeCell<M>,
    /// Receives the task again whenever it has to be (re)scheduled.
    scheduler: S,
}
51
// NOTE(review): both impls are unconditional -- they do not require `F`, `S`
// or `M` to be `Send`/`Sync`. `Task::new_local_with` accepts futures without a
// `Send` bound, so soundness depends on usage rules that are not visible in
// this file -- TODO confirm and document the invariant here.
unsafe impl<F: Future, S: Scheduler<M>, M> Send for RawTaskInner<F, S, M> {}
unsafe impl<F: Future, S: Scheduler<M>, M> Sync for RawTaskInner<F, S, M> {}
54
/// Owning handle to a spawned future together with its metadata of type `M`
/// (defaults to `()`).
///
/// A task is driven with [`Task::poll`] and handed back to its scheduler with
/// [`Task::schedule`].
pub struct Task<M = ()> {
    /// Type-erased pointer to the shared task allocation.
    raw: RawTask,
    /// Records the metadata type, which `raw` has erased.
    _meta: PhantomData<M>,
}
59
// NOTE(review): `Send` and `Sync` are asserted for every `M`. The constructors
// in this file require `M: Send` but not `M: Sync`, while `Task::metadata`
// hands out `&M` from `&Task` -- confirm whether `Sync` should additionally
// require `M: Sync`.
unsafe impl<M> Send for Task<M> {}
unsafe impl<M> Sync for Task<M> {}

// Unwind safety is asserted unconditionally as well; the raw `poll` in this
// file wraps the future's poll in `catch_unwind`.
impl<M> std::panic::UnwindSafe for Task<M> {}
impl<M> std::panic::RefUnwindSafe for Task<M> {}
65
/// Wrapper around a [`Task`] that only exposes the task's metadata.
///
/// Produced by [`Task::poll`] inside [`Status::Complete`].
pub struct Metadata<M>(Task<M>);
67
68impl<M> Metadata<M> {
69 pub fn get(&self) -> &M {
70 self.0.metadata()
71 }
72
73 pub fn get_mut(&mut self) -> &mut M {
74 self.0.metadata_mut()
75 }
76}
77
/// Outcome of a single call to [`Task::poll`].
pub enum Status<M> {
    /// The raw poll reported `PollStatus::Yield`; the task is handed back to
    /// the caller.
    Yielded(Task<M>),
    /// The raw poll reported `PollStatus::Pending`; the polled `Task` handle
    /// is dropped.
    Pending,
    /// The raw poll reported `PollStatus::Complete`; only the task's metadata
    /// remains reachable.
    Complete(Metadata<M>),
}
83
84impl Task {
85 pub fn new<F, S>(future: F, scheduler: S) -> (Self, JoinHandle<F::Output>)
86 where
87 F: Future + Send + 'static,
88 F::Output: Send,
89 S: Scheduler<()>,
90 {
91 Self::new_with((), future, scheduler)
92 }
93
94 pub fn new_local<F, S>(future: F, scheduler: S) -> (Self, JoinHandle<F::Output>)
95 where
96 F: Future + 'static,
97 F::Output: 'static,
98 S: Scheduler<()>,
99 {
100 Self::new_local_with((), future, scheduler)
101 }
102}
103
104impl<M> Task<M> {
105 pub fn metadata(&self) -> &M {
106 unsafe { &*self.raw.metadata().cast::<M>() }
107 }
108
109 pub fn metadata_mut(&mut self) -> &mut M {
110 unsafe { &mut *self.raw.metadata().cast::<M>() }
111 }
112
113 pub fn new_with<F, S>(meta: M, future: F, scheduler: S) -> (Self, JoinHandle<F::Output>)
114 where
115 M: 'static + Send,
116 F: Future + Send + 'static,
117 F::Output: Send,
118 S: Scheduler<M>,
119 {
120 let raw = Arc::new(RawTaskInner {
121 header: Header::new(),
122 future: UnsafeCell::new(Fut::Future(future)),
123 meta: UnsafeCell::new(meta),
124 scheduler,
125 });
126 let join_handle = JoinHandle::new(raw.clone());
127 (
128 Self {
129 raw,
130 _meta: PhantomData,
131 },
132 join_handle,
133 )
134 }
135
136 pub fn new_local_with<F, S>(meta: M, future: F, scheduler: S) -> (Self, JoinHandle<F::Output>)
137 where
138 M: 'static + Send,
139 F: Future + 'static,
140 F::Output: 'static,
141 S: Scheduler<M>,
142 {
143 let raw = Arc::new(RawTaskInner {
144 header: Header::new(),
145 future: UnsafeCell::new(Fut::Future(future)),
146 meta: UnsafeCell::new(meta),
147 scheduler,
148 });
149 let join_handle = JoinHandle::new(raw.clone());
150 (
151 Self {
152 raw,
153 _meta: PhantomData,
154 },
155 join_handle,
156 )
157 }
158
159 #[inline]
160 pub fn poll(self) -> Status<M> {
161 let raw = unsafe { Arc::from_raw(Arc::as_ptr(&self.raw)) };
163 let waker = ManuallyDrop::new(raw.waker());
165
166 match unsafe { self.raw.poll(&waker) } {
168 PollStatus::Yield => Status::Yielded(self),
169 PollStatus::Pending => Status::Pending,
170 PollStatus::Complete => Status::Complete(Metadata(self)),
171 }
172 }
173
174 #[inline]
175 pub fn schedule(self) {
176 unsafe { self.raw.schedule() }
177 }
178
179 #[inline]
180 pub fn id(&self) -> Id {
181 Id::new(&self.raw)
182 }
183}
184
/// Type-erased operations on a task allocation.
///
/// Most methods are `unsafe`; their caller contracts are not spelled out in
/// this file, so the notes below are limited to what the bodies show.
impl<F, S, M> RawTaskVTable for RawTaskInner<F, S, M>
where
    M: 'static,
    F: Future + 'static,
    S: Scheduler<M>,
{
    /// Turns this `Arc` into a [`Waker`] via the `Wake` impl for this type.
    #[inline]
    fn waker(self: Arc<Self>) -> Waker {
        Waker::from(self)
    }

    /// Returns the task header.
    #[inline]
    fn header(&self) -> &Header {
        &self.header
    }

    /// Returns a type-erased pointer to the metadata cell's contents.
    unsafe fn metadata(&self) -> *mut () {
        self.meta.get().cast()
    }

    /// Polls the stored future once with `waker`.
    ///
    /// Returns `PollStatus::Complete` once an output (value, cancellation
    /// error or panic error) has been recorded; otherwise returns whatever
    /// `Header::transition_to_sleep` reports.
    ///
    /// # Safety
    ///
    /// NOTE(review): the body mutably dereferences `self.future` through the
    /// `UnsafeCell`, so callers presumably must guarantee that nothing else
    /// touches the cell concurrently -- TODO confirm the exact contract.
    unsafe fn poll(&self, waker: &Waker) -> PollStatus {
        // Judging by its name this also marks the task as running; only the
        // returned cancellation flag is used here.
        let is_cancelled = self.header.transition_to_running_and_check_if_cancelled();

        // A panic raised while polling is caught and reported through the
        // join result below instead of unwinding into the caller.
        let has_output = catch_unwind(AssertUnwindSafe(|| {
            let poll_result = unsafe {
                // The cell is expected to still hold the future here; any
                // other variant is treated as a bug.
                let fut = match &mut *self.future.get() {
                    Fut::Future(fut) => Pin::new_unchecked(fut),
                    _ => unreachable!(),
                };
                fut.poll(&mut Context::from_waker(waker))
            };
            let result = match poll_result {
                Poll::Ready(val) => Ok(val),
                // A cancelled task still gets this one poll; if it is not
                // ready, the join result becomes a cancellation error.
                Poll::Pending if is_cancelled => Err(JoinError::cancelled()),
                // Not ready and not cancelled: nothing to store yet.
                Poll::Pending => return false,
            };
            unsafe {
                (*self.future.get()).set_output(result);
            }
            true
        }));

        match has_output {
            Ok(false) => return self.header.transition_to_sleep(),
            Ok(true) => {}
            // `err` is the panic payload captured by `catch_unwind`.
            Err(err) => unsafe { (*self.future.get()).set_output(Err(JoinError::panic(err))) },
        }
        // When the transition returns `false` the stored output is dropped
        // right here (the name suggests that nobody is waiting for it). A
        // panic raised by that drop is deliberately swallowed.
        if !self
            .header
            .transition_to_complete_and_notify_output_if_intrested()
        {
            let _ = catch_unwind(AssertUnwindSafe(|| unsafe { (*self.future.get()).drop() }));
        }
        PollStatus::Complete
    }

    /// Wraps a clone of this `Arc` in a [`Task`] and passes it to the
    /// scheduler; `self` itself is dropped when the call returns.
    unsafe fn schedule(self: Arc<Self>) {
        self.scheduler.schedule(Task {
            raw: self.clone(),
            _meta: PhantomData,
        });
    }

    /// Schedules the task if `Header::transition_to_abort` returns `true`.
    unsafe fn abort_task(self: Arc<Self>) {
        if self.header.transition_to_abort() {
            self.schedule()
        }
    }

    /// Writes `Poll::Ready(output)` through `dst` if the header reports that
    /// the output can be read; otherwise `dst` is left untouched (per the
    /// method name, the header presumably keeps `waker` to notify later).
    ///
    /// # Safety
    ///
    /// `dst` is cast to an inferred pointer type and assigned through, which
    /// drops the value previously stored there. Assumes `dst` points to a
    /// valid, initialized `Poll` of this task's join result type -- TODO
    /// confirm against the `JoinHandle` implementation.
    unsafe fn read_output(&self, dst: *mut (), waker: &Waker) {
        if self.header.can_read_output_or_notify_when_readable(waker) {
            *(dst as *mut _) = Poll::Ready((*self.future.get()).take_output());
        }
    }

    /// Cleanup performed when the join handle goes away (per the name).
    ///
    /// If the state reports the task as complete, whatever the cell still
    /// holds is dropped (panics swallowed); otherwise the stored join waker is
    /// cleared.
    unsafe fn drop_join_handler(&self) {
        let is_task_complete = self.header.state.unset_waker_and_interested();
        if is_task_complete {
            let _ = catch_unwind(AssertUnwindSafe(|| unsafe {
                (*self.future.get()).drop();
            }));
        } else {
            *self.header.join_waker.get() = None;
        }
    }
}
282
283impl<F, S, M> RawTaskInner<F, S, M>
284where
285 M: 'static,
286 F: Future + 'static,
287 S: Scheduler<M>,
288{
289 unsafe fn schedule_by_ref(self: &Arc<Self>) {
290 self.scheduler.schedule(Task {
291 raw: self.clone(),
292 _meta: PhantomData,
293 });
294 }
295}
296
297impl<F, S, M> Wake for RawTaskInner<F, S, M>
298where
299 M: 'static,
300 F: Future + 'static,
301 S: Scheduler<M>,
302{
303 fn wake(self: Arc<Self>) {
304 unsafe {
305 if self.header.transition_to_notified() {
306 self.schedule();
307 }
308 }
309 }
310
311 fn wake_by_ref(self: &Arc<Self>) {
312 unsafe {
313 if self.header.transition_to_notified() {
314 self.schedule_by_ref();
315 }
316 }
317 }
318}
319
320impl<M> fmt::Debug for Task<M> {
321 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322 f.debug_struct("Task")
323 .field("id", &self.id())
324 .field("state", &self.raw.header().state.load())
325 .finish()
326 }
327}