1use crate::{Instant, Priority, RunnableMeta, Scheduler, SessionId, Timer};
2use async_task::Runnable;
3use std::{
4 any::Any,
5 future::Future,
6 marker::PhantomData,
7 mem::ManuallyDrop,
8 panic::Location,
9 pin::Pin,
10 rc::Rc,
11 sync::Arc,
12 task::{Context, Poll},
13 thread::{self, ThreadId},
14 time::Duration,
15};
16
17#[derive(Clone)]
22pub struct LocalExecutor {
23 session_id: SessionId,
24 scheduler: Arc<dyn Scheduler>,
25 dispatch: Arc<dyn Fn(Runnable<RunnableMeta>) + Send + Sync>,
29 not_send: PhantomData<Rc<()>>,
30}
31
32impl LocalExecutor {
33 pub fn new(
42 session_id: SessionId,
43 scheduler: Arc<dyn Scheduler>,
44 dispatch: impl Fn(Runnable<RunnableMeta>) + Send + Sync + 'static,
45 ) -> Self {
46 Self {
47 session_id,
48 scheduler,
49 dispatch: Arc::new(dispatch),
50 not_send: PhantomData,
51 }
52 }
53
54 pub fn session_id(&self) -> SessionId {
55 self.session_id
56 }
57
58 pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
59 &self.scheduler
60 }
61
62 #[track_caller]
63 pub fn spawn<F>(&self, future: F) -> Task<F::Output>
64 where
65 F: Future + 'static,
66 F::Output: 'static,
67 {
68 let dispatch = self.dispatch.clone();
69 let location = Location::caller();
70 let (runnable, task) = spawn_local_with_source_location(
71 future,
72 move |runnable| dispatch(runnable),
73 RunnableMeta {
74 location,
75 spawned: crate::SpawnTime(Instant::now()),
76 },
77 );
78 runnable.schedule();
79 Task(TaskState::Spawned(task))
80 }
81
82 pub fn block_on<Fut: Future>(&self, future: Fut) -> Fut::Output {
83 use std::cell::Cell;
84
85 let output = Cell::new(None);
86 let future = async {
87 output.set(Some(future.await));
88 };
89 let mut future = std::pin::pin!(future);
90
91 self.scheduler
92 .block(Some(self.session_id), future.as_mut(), None);
93
94 output.take().expect("block_on future did not complete")
95 }
96
97 pub fn block_with_timeout<Fut: Future>(
100 &self,
101 timeout: Duration,
102 future: Fut,
103 ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
104 use std::cell::Cell;
105
106 let output = Cell::new(None);
107 let mut future = Box::pin(future);
108
109 {
110 let future_ref = &mut future;
111 let wrapper = async {
112 output.set(Some(future_ref.await));
113 };
114 let mut wrapper = std::pin::pin!(wrapper);
115
116 self.scheduler
117 .block(Some(self.session_id), wrapper.as_mut(), Some(timeout));
118 }
119
120 match output.take() {
121 Some(value) => Ok(value),
122 None => Err(future),
123 }
124 }
125
126 #[track_caller]
127 pub fn timer(&self, duration: Duration) -> Timer {
128 self.scheduler.timer(duration)
129 }
130
131 pub fn now(&self) -> Instant {
132 self.scheduler.clock().now()
133 }
134
135 #[track_caller]
144 pub fn spawn_dedicated<F, Fut>(&self, f: F) -> Task<Fut::Output>
145 where
146 F: FnOnce(LocalExecutor) -> Fut + Send + 'static,
147 Fut: Future + 'static,
148 Fut::Output: Send + Sync + 'static,
149 {
150 self.scheduler
151 .clone()
152 .spawn_dedicated(box_dedicated(f))
153 .downcast::<Fut::Output>()
154 }
155}
156
157fn box_dedicated<F, Fut>(
162 f: F,
163) -> Box<
164 dyn FnOnce(LocalExecutor) -> Pin<Box<dyn Future<Output = Box<dyn Any + Send + Sync>> + 'static>>
165 + Send
166 + 'static,
167>
168where
169 F: FnOnce(LocalExecutor) -> Fut + Send + 'static,
170 Fut: Future + 'static,
171 Fut::Output: Send + Sync + 'static,
172{
173 Box::new(move |executor| {
174 Box::pin(async move { Box::new(f(executor).await) as Box<dyn Any + Send + Sync> })
175 })
176}
177
178#[derive(Clone)]
179pub struct BackgroundExecutor {
180 scheduler: Arc<dyn Scheduler>,
181}
182
183impl BackgroundExecutor {
184 pub fn new(scheduler: Arc<dyn Scheduler>) -> Self {
185 Self { scheduler }
186 }
187
188 #[track_caller]
189 pub fn spawn<F>(&self, future: F) -> Task<F::Output>
190 where
191 F: Future + Send + 'static,
192 F::Output: Send + 'static,
193 {
194 self.spawn_with_priority(Priority::default(), future)
195 }
196
197 #[track_caller]
198 pub fn spawn_with_priority<F>(&self, priority: Priority, future: F) -> Task<F::Output>
199 where
200 F: Future + Send + 'static,
201 F::Output: Send + 'static,
202 {
203 let scheduler = Arc::downgrade(&self.scheduler);
204 let location = Location::caller();
205 let (runnable, task) = async_task::Builder::new()
206 .metadata(RunnableMeta {
207 location,
208 spawned: crate::SpawnTime(Instant::now()),
209 })
210 .spawn(
211 move |_| future,
212 move |runnable| {
213 if let Some(scheduler) = scheduler.upgrade() {
214 scheduler.schedule_background_with_priority(runnable, priority);
215 }
216 },
217 );
218 runnable.schedule();
219 Task(TaskState::Spawned(task))
220 }
221
222 #[track_caller]
224 pub fn spawn_realtime<F>(&self, future: F) -> Task<F::Output>
225 where
226 F: Future + Send + 'static,
227 F::Output: Send + 'static,
228 {
229 let location = Location::caller();
230 let (tx, rx) = flume::bounded::<async_task::Runnable<RunnableMeta>>(1);
231
232 self.scheduler.spawn_realtime(Box::new(move || {
233 while let Ok(runnable) = rx.recv() {
234 runnable.run();
235 }
236 }));
237
238 let (runnable, task) = async_task::Builder::new()
239 .metadata(RunnableMeta {
240 location,
241 spawned: crate::SpawnTime(Instant::now()),
242 })
243 .spawn(
244 move |_| future,
245 move |runnable| {
246 let _ = tx.send(runnable);
247 },
248 );
249 runnable.schedule();
250 Task(TaskState::Spawned(task))
251 }
252
253 #[track_caller]
254 pub fn timer(&self, duration: Duration) -> Timer {
255 self.scheduler.timer(duration)
256 }
257
258 pub fn now(&self) -> Instant {
259 self.scheduler.clock().now()
260 }
261
262 pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
263 &self.scheduler
264 }
265
266 #[track_caller]
275 pub fn spawn_dedicated<F, Fut>(&self, f: F) -> Task<Fut::Output>
276 where
277 F: FnOnce(LocalExecutor) -> Fut + Send + 'static,
278 Fut: Future + 'static,
279 Fut::Output: Send + Sync + 'static,
280 {
281 self.scheduler
282 .clone()
283 .spawn_dedicated(box_dedicated(f))
284 .downcast::<Fut::Output>()
285 }
286}
287
288#[must_use]
295pub struct Task<T>(TaskState<T>);
296
297enum TaskState<T> {
298 Ready(Option<T>),
300
301 Spawned(async_task::Task<T, RunnableMeta>),
303
304 Downcast {
308 inner: Box<Task<Box<dyn Any + Send + Sync>>>,
309 marker: PhantomData<fn() -> T>,
310 },
311}
312
313impl<T> Task<T> {
314 pub fn ready(val: T) -> Self {
316 Task(TaskState::Ready(Some(val)))
317 }
318
319 pub fn from_async_task(task: async_task::Task<T, RunnableMeta>) -> Self {
321 Task(TaskState::Spawned(task))
322 }
323
324 pub fn is_ready(&self) -> bool {
325 match &self.0 {
326 TaskState::Ready(_) => true,
327 TaskState::Spawned(task) => task.is_finished(),
328 TaskState::Downcast { inner, .. } => inner.is_ready(),
329 }
330 }
331
332 pub fn detach(self) {
334 match self {
335 Task(TaskState::Ready(_)) => {}
336 Task(TaskState::Spawned(task)) => task.detach(),
337 Task(TaskState::Downcast { inner, .. }) => inner.detach(),
338 }
339 }
340
341 pub fn fallible(self) -> FallibleTask<T> {
343 FallibleTask(match self.0 {
344 TaskState::Ready(val) => FallibleTaskState::Ready(val),
345 TaskState::Spawned(task) => FallibleTaskState::Spawned(task.fallible()),
346 TaskState::Downcast { inner, .. } => FallibleTaskState::Downcast {
347 inner: Box::new(inner.fallible()),
348 marker: PhantomData,
349 },
350 })
351 }
352}
353
354impl Task<Box<dyn Any + Send + Sync>> {
355 pub fn downcast<T: Send + Sync + 'static>(self) -> Task<T> {
364 Task(TaskState::Downcast {
365 inner: Box::new(self),
366 marker: PhantomData,
367 })
368 }
369}
370
371impl<T> std::fmt::Debug for Task<T> {
372 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373 match &self.0 {
374 TaskState::Ready(_) => f.debug_tuple("Task::Ready").finish(),
375 TaskState::Spawned(task) => f.debug_tuple("Task::Spawned").field(task).finish(),
376 TaskState::Downcast { inner, .. } => {
377 f.debug_tuple("Task::Downcast").field(inner).finish()
378 }
379 }
380 }
381}
382
383#[must_use]
385pub struct FallibleTask<T>(FallibleTaskState<T>);
386
387enum FallibleTaskState<T> {
388 Ready(Option<T>),
390
391 Spawned(async_task::FallibleTask<T, RunnableMeta>),
393
394 Downcast {
396 inner: Box<FallibleTask<Box<dyn Any + Send + Sync>>>,
397 marker: PhantomData<fn() -> T>,
398 },
399}
400
401impl<T> FallibleTask<T> {
402 pub fn ready(val: T) -> Self {
404 FallibleTask(FallibleTaskState::Ready(Some(val)))
405 }
406
407 pub fn detach(self) {
409 match self.0 {
410 FallibleTaskState::Ready(_) => {}
411 FallibleTaskState::Spawned(task) => task.detach(),
412 FallibleTaskState::Downcast { inner, .. } => inner.detach(),
413 }
414 }
415}
416
417impl<T: 'static> Future for FallibleTask<T> {
418 type Output = Option<T>;
419
420 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
421 match unsafe { self.get_unchecked_mut() } {
422 FallibleTask(FallibleTaskState::Ready(val)) => Poll::Ready(val.take()),
423 FallibleTask(FallibleTaskState::Spawned(task)) => Pin::new(task).poll(cx),
424 FallibleTask(FallibleTaskState::Downcast { inner, .. }) => {
425 match Pin::new(inner.as_mut()).poll(cx) {
426 Poll::Ready(Some(boxed_any)) => Poll::Ready(Some(
427 *boxed_any
428 .downcast::<T>()
429 .expect("FallibleTask::poll: downcast type mismatch"),
430 )),
431 Poll::Ready(None) => Poll::Ready(None),
432 Poll::Pending => Poll::Pending,
433 }
434 }
435 }
436 }
437}
438
439impl<T> std::fmt::Debug for FallibleTask<T> {
440 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
441 match &self.0 {
442 FallibleTaskState::Ready(_) => f.debug_tuple("FallibleTask::Ready").finish(),
443 FallibleTaskState::Spawned(task) => {
444 f.debug_tuple("FallibleTask::Spawned").field(task).finish()
445 }
446 FallibleTaskState::Downcast { inner, .. } => f
447 .debug_tuple("FallibleTask::Downcast")
448 .field(inner)
449 .finish(),
450 }
451 }
452}
453
454impl<T: 'static> Future for Task<T> {
455 type Output = T;
456
457 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
458 match unsafe { self.get_unchecked_mut() } {
459 Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
460 Task(TaskState::Spawned(task)) => Pin::new(task).poll(cx),
461 Task(TaskState::Downcast { inner, .. }) => match Pin::new(inner.as_mut()).poll(cx) {
462 Poll::Ready(boxed_any) => Poll::Ready(
463 *boxed_any
464 .downcast::<T>()
465 .expect("Task::poll: downcast type mismatch"),
466 ),
467 Poll::Pending => Poll::Pending,
468 },
469 }
470 }
471}
472
473#[track_caller]
475fn spawn_local_with_source_location<Fut, S>(
476 future: Fut,
477 schedule: S,
478 metadata: RunnableMeta,
479) -> (
480 async_task::Runnable<RunnableMeta>,
481 async_task::Task<Fut::Output, RunnableMeta>,
482)
483where
484 Fut: Future + 'static,
485 Fut::Output: 'static,
486 S: async_task::Schedule<RunnableMeta> + Send + Sync + 'static,
487{
488 #[inline]
489 fn thread_id() -> ThreadId {
490 std::thread_local! {
491 static ID: ThreadId = thread::current().id();
492 }
493 ID.try_with(|id| *id)
494 .unwrap_or_else(|_| thread::current().id())
495 }
496
497 struct Checked<F> {
498 id: ThreadId,
499 inner: ManuallyDrop<F>,
500 location: &'static Location<'static>,
501 }
502
503 impl<F> Drop for Checked<F> {
504 fn drop(&mut self) {
505 assert_eq!(
506 self.id,
507 thread_id(),
508 "local task dropped by a thread that didn't spawn it. Task spawned at {}",
509 self.location
510 );
511 unsafe { ManuallyDrop::drop(&mut self.inner) };
515 }
516 }
517
518 impl<F: Future> Future for Checked<F> {
519 type Output = F::Output;
520
521 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
522 let this = unsafe { self.get_unchecked_mut() };
526 assert!(
527 this.id == thread_id(),
528 "local task polled by a thread that didn't spawn it. Task spawned at {}",
529 this.location
530 );
531 unsafe { Pin::new_unchecked(&mut *this.inner).poll(cx) }
536 }
537 }
538
539 let location = metadata.location;
540
541 let future = move |_| Checked {
542 id: thread_id(),
543 inner: ManuallyDrop::new(future),
544 location,
545 };
546
547 let builder = async_task::Builder::new().metadata(metadata);
548 unsafe { builder.spawn_unchecked(future, schedule) }
552}