1#![doc = include_str!("../README.md")]
2#![allow(unsafe_op_in_unsafe_fn)]
3
4mod abort;
5mod blocking;
6mod error;
7mod id;
8mod join;
9mod raw;
10mod state;
11mod task;
12mod thin_arc;
13mod waker;
14
15use crate::{raw::*, thin_arc::ThinArc};
16
17pub use abort::AbortHandle;
18pub use blocking::BlockingTask;
19pub use error::JoinError;
20pub use id::{TaskId, id};
21pub use join::JoinHandle;
22
23use state::*;
24use std::{
25 cell::UnsafeCell,
26 fmt::{self, Debug},
27 future::Future,
28 marker::PhantomData,
29 mem::ManuallyDrop,
30};
31
32pub trait Scheduler<M = ()>: 'static {
33 fn schedule(&self, task: Task<M>);
34}
35
36impl<F, M> Scheduler<M> for F
37where
38 F: Fn(Task<M>) + 'static,
39{
40 fn schedule(&self, runnable: Task<M>) {
41 self(runnable)
42 }
43}
44
45pub struct Task<M = ()> {
46 raw: Option<RawTask>,
47 _meta: PhantomData<M>,
48}
49
50unsafe impl<M> Send for Task<M> {}
51unsafe impl<M> Sync for Task<M> {}
52
53impl<M> std::panic::UnwindSafe for Task<M> {}
54impl<M> std::panic::RefUnwindSafe for Task<M> {}
55
56impl<M> Drop for Task<M> {
57 fn drop(&mut self) {
58 if let Some(raw) = self.raw.take() {
59 unsafe { raw.drop_task() };
60 }
61 }
62}
63
64pub struct Metadata<M = ()> {
65 raw: RawTask,
66 _meta: PhantomData<M>,
67}
68
69impl<M: Debug> Debug for Metadata<M> {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 self.get().fmt(f)
72 }
73}
74
75impl<M> Metadata<M> {
76 pub fn id(&self) -> TaskId {
77 TaskId::new(&self.raw)
78 }
79
80 pub fn get(&self) -> &M {
81 unsafe { &*self.raw.metadata().cast::<M>() }
82 }
83
84 pub fn get_mut(&mut self) -> &mut M {
85 unsafe { &mut *self.raw.metadata().cast::<M>() }
86 }
87}
88
89#[derive(Debug)]
90pub enum Status<M> {
91 Yielded(Task<M>),
92 Pending,
93 Complete(Metadata<M>),
94}
95
96impl Task {
97 pub fn new<F, S>(future: F, scheduler: S) -> (Task, JoinHandle<F::Output>)
98 where
99 S: Scheduler<()> + Send,
100 F: Future + Send + 'static,
101 F::Output: Send + 'static,
102 {
103 unsafe { Self::new_unchecked((), future, scheduler) }
104 }
105
106 pub fn new_local<F, S>(future: F, scheduler: S) -> (Task, JoinHandle<F::Output>)
107 where
108 S: Scheduler<()> + Send,
109 F: Future + 'static,
110 F::Output: 'static,
111 {
112 Self::new_local_with((), future, scheduler)
113 }
114}
115
116impl<M> Task<M> {
117 pub(crate) fn from_raw(raw: RawTask) -> Self {
118 Self {
119 raw: Some(raw),
120 _meta: PhantomData,
121 }
122 }
123
124 pub fn metadata(&self) -> &M {
125 unsafe { &*self.raw.as_ref().unwrap_unchecked().metadata().cast() }
126 }
127
128 pub fn metadata_mut(&mut self) -> &mut M {
129 unsafe { &mut *self.raw.as_ref().unwrap_unchecked().metadata().cast() }
130 }
131
132 pub unsafe fn new_unchecked<F, S>(
133 meta: M,
134 future: F,
135 scheduler: S,
136 ) -> (Task<M>, JoinHandle<F::Output>)
137 where
138 M: 'static,
139 S: Scheduler<M>,
140 F: Future + 'static,
141 {
142 let (raw, join) = ThinArc::new(Box::new(RawTaskHeader {
143 header: Header::new(),
144 data: task::RawTaskInner {
145 future: UnsafeCell::new(Fut::Future(future)),
146 meta: UnsafeCell::new(meta),
147 scheduler,
148 },
149 }));
150 (Task::from_raw(raw), JoinHandle::new(join))
151 }
152
153 pub fn new_with<F, S>(meta: M, future: F, scheduler: S) -> (Task<M>, JoinHandle<F::Output>)
154 where
155 M: 'static + Send,
156 S: Scheduler<M> + Send,
157 F: Future + Send + 'static,
158 F::Output: Send + 'static,
159 {
160 unsafe { Self::new_unchecked(meta, future, scheduler) }
161 }
162
163 pub fn new_local_with<F, S>(
164 meta: M,
165 future: F,
166 scheduler: S,
167 ) -> (Task<M>, JoinHandle<F::Output>)
168 where
169 M: 'static + Send,
170 S: Scheduler<M> + Send,
171 F: Future + 'static,
172 F::Output: 'static,
173 {
174 use std::{
175 mem::ManuallyDrop,
176 pin::Pin,
177 task::{Context, Poll},
178 thread::{self, ThreadId},
179 };
180
181 #[inline]
182 fn thread_id() -> ThreadId {
183 std::thread_local! {
184 static ID: ThreadId = thread::current().id();
185 }
186 ID.try_with(|id| *id)
187 .unwrap_or_else(|_| thread::current().id())
188 }
189
190 struct Checked<F> {
191 id: ThreadId,
192 inner: ManuallyDrop<F>,
193 }
194
195 impl<F> Drop for Checked<F> {
196 fn drop(&mut self) {
197 assert!(
198 self.id == thread_id(),
199 "local task dropped by a thread that didn't spawn it"
200 );
201 unsafe {
202 ManuallyDrop::drop(&mut self.inner);
203 }
204 }
205 }
206
207 impl<F: Future> Future for Checked<F> {
208 type Output = F::Output;
209 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
210 unsafe {
211 let me = self.get_unchecked_mut();
212 assert!(
213 me.id == thread_id(),
214 "local task polled by a thread that didn't spawn it"
215 );
216 Pin::new_unchecked(&mut *me.inner).poll(cx)
217 }
218 }
219 }
220
221 let future = Checked {
222 id: thread_id(),
223 inner: ManuallyDrop::new(future),
224 };
225
226 unsafe { Self::new_unchecked(meta, future, scheduler) }
227 }
228
229 #[inline]
230 pub fn poll(mut self) -> Status<M> {
231 let raw = unsafe { self.raw.take().unwrap_unchecked() };
232 let waker = raw.clone_without_ref_inc();
234 let waker = ManuallyDrop::new(raw.waker(waker));
236
237 match unsafe { raw.poll(&waker) } {
239 PollStatus::Yield => Status::Yielded(Task::from_raw(raw)),
240 PollStatus::Pending => Status::Pending,
241 PollStatus::Complete => Status::Complete(Metadata {
242 raw,
243 _meta: PhantomData,
244 }),
245 }
246 }
247
248 #[inline]
249 pub fn schedule(mut self) {
250 unsafe {
251 let raw = self.raw.take().unwrap_unchecked();
252 raw.schedule(raw.clone());
253 }
254 }
255
256 #[inline]
257 pub fn id(&self) -> TaskId {
258 TaskId::new(unsafe { self.raw.as_ref().unwrap_unchecked() })
259 }
260}
261
262impl<M: Debug> fmt::Debug for Task<M> {
263 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264 f.debug_struct("Task")
265 .field("id", &self.id())
266 .field("state", &self.raw.as_ref().unwrap().header().state.load())
267 .field("metadata", self.metadata())
268 .finish()
269 }
270}