gpui/
executor.rs

1use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant, TaskTiming, profiler};
2use async_task::Runnable;
3use futures::channel::mpsc;
4use parking_lot::{Condvar, Mutex};
5use smol::prelude::*;
6use std::{
7    fmt::Debug,
8    marker::PhantomData,
9    mem::{self, ManuallyDrop},
10    num::NonZeroUsize,
11    panic::Location,
12    pin::Pin,
13    rc::Rc,
14    sync::{
15        Arc,
16        atomic::{AtomicUsize, Ordering},
17    },
18    task::{Context, Poll},
19    thread::{self, ThreadId},
20    time::{Duration, Instant},
21};
22use util::TryFutureExt as _;
23use waker_fn::waker_fn;
24
25#[cfg(any(test, feature = "test-support"))]
26use rand::rngs::StdRng;
27
28/// A pointer to the executor that is currently running,
29/// for spawning background tasks.
30#[derive(Clone)]
31pub struct BackgroundExecutor {
32    #[doc(hidden)]
33    pub dispatcher: Arc<dyn PlatformDispatcher>,
34}
35
36/// A pointer to the executor that is currently running,
37/// for spawning tasks on the main thread.
38///
39/// This is intentionally `!Send` via the `not_send` marker field. This is because
40/// `ForegroundExecutor::spawn` does not require `Send` but checks at runtime that the future is
41/// only polled from the same thread it was spawned from. These checks would fail when spawning
42/// foreground tasks from background threads.
43#[derive(Clone)]
44pub struct ForegroundExecutor {
45    #[doc(hidden)]
46    pub dispatcher: Arc<dyn PlatformDispatcher>,
47    liveness: std::sync::Weak<()>,
48    not_send: PhantomData<Rc<()>>,
49}
50
/// Flavour of a realtime task.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum RealtimePriority {
    /// Audio task
    Audio,
    /// Any other realtime task (the default).
    #[default]
    Other,
}

/// How urgently a spawned task should be run.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum Priority {
    /// Realtime priority
    ///
    /// A task spawned with this priority gets a separate thread dedicated solely to it.
    Realtime(RealtimePriority),
    /// High priority
    ///
    /// Reserve this for work that is critical to the user experience / responsiveness
    /// of the editor.
    High,
    /// Medium priority, the default and a good fit for most work.
    #[default]
    Medium,
    /// Low priority
    ///
    /// Use this for background work that may arrive in large volumes, so that it does
    /// not starve the executor of resources needed by high-priority tasks.
    Low,
}

impl Priority {
    /// Relative weight of this priority for probability-based scheduling.
    ///
    /// Realtime priorities take no part in that scheduling and always report `0`.
    // NOTE(review): presumably not every dispatcher build references this; confirm
    // before removing the lint allowance.
    #[allow(dead_code)]
    pub(crate) const fn probability(&self) -> u32 {
        match *self {
            Self::High => 60,
            Self::Medium => 30,
            Self::Low => 10,
            Self::Realtime(_) => 0,
        }
    }
}
96
97/// Task is a primitive that allows work to happen in the background.
98///
99/// It implements [`Future`] so you can `.await` on it.
100///
101/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
102/// the task to continue running, but with no way to return a value.
103#[must_use]
104#[derive(Debug)]
105pub struct Task<T>(TaskState<T>);
106
107#[derive(Debug)]
108enum TaskState<T> {
109    /// A task that is ready to return a value
110    Ready(Option<T>),
111
112    /// A task that is currently running.
113    Spawned(async_task::Task<T, RunnableMeta>),
114}
115
116impl<T> Task<T> {
117    /// Creates a new task that will resolve with the value
118    pub fn ready(val: T) -> Self {
119        Task(TaskState::Ready(Some(val)))
120    }
121
122    /// Detaching a task runs it to completion in the background
123    pub fn detach(self) {
124        match self {
125            Task(TaskState::Ready(_)) => {}
126            Task(TaskState::Spawned(task)) => task.detach(),
127        }
128    }
129
130    /// Converts this task into a fallible task that returns `Option<T>`.
131    ///
132    /// Unlike the standard `Task<T>`, a [`FallibleTask`] will return `None`
133    /// if the app was dropped while the task is executing.
134    ///
135    /// # Example
136    ///
137    /// ```ignore
138    /// // Background task that gracefully handles app shutdown:
139    /// cx.background_spawn(async move {
140    ///     let result = foreground_task.fallible().await;
141    ///     if let Some(value) = result {
142    ///         // Process the value
143    ///     }
144    ///     // If None, app was shut down - just exit gracefully
145    /// }).detach();
146    /// ```
147    pub fn fallible(self) -> FallibleTask<T> {
148        FallibleTask(match self.0 {
149            TaskState::Ready(val) => FallibleTaskState::Ready(val),
150            TaskState::Spawned(task) => FallibleTaskState::Spawned(task.fallible()),
151        })
152    }
153}
154
155impl<E, T> Task<Result<T, E>>
156where
157    T: 'static,
158    E: 'static + Debug,
159{
160    /// Run the task to completion in the background and log any
161    /// errors that occur.
162    #[track_caller]
163    pub fn detach_and_log_err(self, cx: &App) {
164        let location = core::panic::Location::caller();
165        cx.foreground_executor()
166            .spawn(self.log_tracked_err(*location))
167            .detach();
168    }
169}
170
171impl<T> Future for Task<T> {
172    type Output = T;
173
174    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
175        match unsafe { self.get_unchecked_mut() } {
176            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
177            Task(TaskState::Spawned(task)) => task.poll(cx),
178        }
179    }
180}
181
182/// A task that returns `Option<T>` instead of panicking when cancelled.
183#[must_use]
184pub struct FallibleTask<T>(FallibleTaskState<T>);
185
186enum FallibleTaskState<T> {
187    /// A task that is ready to return a value
188    Ready(Option<T>),
189
190    /// A task that is currently running (wraps async_task::FallibleTask).
191    Spawned(async_task::FallibleTask<T, RunnableMeta>),
192}
193
194impl<T> FallibleTask<T> {
195    /// Creates a new fallible task that will resolve with the value.
196    pub fn ready(val: T) -> Self {
197        FallibleTask(FallibleTaskState::Ready(Some(val)))
198    }
199
200    /// Detaching a task runs it to completion in the background.
201    pub fn detach(self) {
202        match self.0 {
203            FallibleTaskState::Ready(_) => {}
204            FallibleTaskState::Spawned(task) => task.detach(),
205        }
206    }
207}
208
209impl<T> Future for FallibleTask<T> {
210    type Output = Option<T>;
211
212    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
213        match unsafe { self.get_unchecked_mut() } {
214            FallibleTask(FallibleTaskState::Ready(val)) => Poll::Ready(val.take()),
215            FallibleTask(FallibleTaskState::Spawned(task)) => Pin::new(task).poll(cx),
216        }
217    }
218}
219
220impl<T> std::fmt::Debug for FallibleTask<T> {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        match &self.0 {
223            FallibleTaskState::Ready(_) => f.debug_tuple("FallibleTask::Ready").finish(),
224            FallibleTaskState::Spawned(task) => {
225                f.debug_tuple("FallibleTask::Spawned").field(task).finish()
226            }
227        }
228    }
229}
230
/// An opaque identifier that tests can use to refer to a specific task.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    /// Same as [`TaskLabel::new`]: every default label is freshly allocated.
    fn default() -> Self {
        TaskLabel::new()
    }
}

impl TaskLabel {
    /// Allocates a new label, distinct from all labels created earlier in this process.
    pub fn new() -> Self {
        // The counter starts at 1 so every value handed out is a valid `NonZeroUsize`.
        static COUNTER: AtomicUsize = AtomicUsize::new(1);
        let id = COUNTER.fetch_add(1, Ordering::SeqCst);
        TaskLabel(NonZeroUsize::try_from(id).unwrap())
    }
}
254
/// Boxed, pinned, type-erased future with no `Send` requirement (spawned on the main thread).
type AnyLocalFuture<R> = Pin<Box<dyn Future<Output = R> + 'static>>;

/// Boxed, pinned, type-erased `Send` future (spawned on background threads).
type AnyFuture<R> = Pin<Box<dyn Future<Output = R> + Send + 'static>>;
258
259/// BackgroundExecutor lets you run things on background threads.
260/// In production this is a thread pool with no ordering guarantees.
261/// In tests this is simulated by running tasks one by one in a deterministic
262/// (but arbitrary) order controlled by the `SEED` environment variable.
263impl BackgroundExecutor {
264    #[doc(hidden)]
265    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
266        Self { dispatcher }
267    }
268
269    /// Enqueues the given future to be run to completion on a background thread.
270    #[track_caller]
271    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
272    where
273        R: Send + 'static,
274    {
275        self.spawn_with_priority(Priority::default(), future)
276    }
277
278    /// Enqueues the given future to be run to completion on a background thread.
279    #[track_caller]
280    pub fn spawn_with_priority<R>(
281        &self,
282        priority: Priority,
283        future: impl Future<Output = R> + Send + 'static,
284    ) -> Task<R>
285    where
286        R: Send + 'static,
287    {
288        self.spawn_internal::<R>(Box::pin(future), None, priority)
289    }
290
291    /// Enqueues the given future to be run to completion on a background thread and blocking the current task on it.
292    ///
293    /// This allows to spawn background work that borrows from its scope. Note that the supplied future will run to
294    /// completion before the current task is resumed, even if the current task is slated for cancellation.
295    pub async fn await_on_background<R>(&self, future: impl Future<Output = R> + Send) -> R
296    where
297        R: Send,
298    {
299        // We need to ensure that cancellation of the parent task does not drop the environment
300        // before the our own task has completed or got cancelled.
301        struct NotifyOnDrop<'a>(&'a (Condvar, Mutex<bool>));
302
303        impl Drop for NotifyOnDrop<'_> {
304            fn drop(&mut self) {
305                *self.0.1.lock() = true;
306                self.0.0.notify_all();
307            }
308        }
309
310        struct WaitOnDrop<'a>(&'a (Condvar, Mutex<bool>));
311
312        impl Drop for WaitOnDrop<'_> {
313            fn drop(&mut self) {
314                let mut done = self.0.1.lock();
315                if !*done {
316                    self.0.0.wait(&mut done);
317                }
318            }
319        }
320
321        let dispatcher = self.dispatcher.clone();
322        let location = core::panic::Location::caller();
323
324        let pair = &(Condvar::new(), Mutex::new(false));
325        let _wait_guard = WaitOnDrop(pair);
326
327        let (runnable, task) = unsafe {
328            async_task::Builder::new()
329                .metadata(RunnableMeta {
330                    location,
331                    app: None,
332                })
333                .spawn_unchecked(
334                    move |_| async {
335                        let _notify_guard = NotifyOnDrop(pair);
336                        future.await
337                    },
338                    move |runnable| {
339                        dispatcher.dispatch(
340                            RunnableVariant::Meta(runnable),
341                            None,
342                            Priority::default(),
343                        )
344                    },
345                )
346        };
347        runnable.schedule();
348        task.await
349    }
350
351    /// Enqueues the given future to be run to completion on a background thread.
352    /// The given label can be used to control the priority of the task in tests.
353    #[track_caller]
354    pub fn spawn_labeled<R>(
355        &self,
356        label: TaskLabel,
357        future: impl Future<Output = R> + Send + 'static,
358    ) -> Task<R>
359    where
360        R: Send + 'static,
361    {
362        self.spawn_internal::<R>(Box::pin(future), Some(label), Priority::default())
363    }
364
365    #[track_caller]
366    fn spawn_internal<R: Send + 'static>(
367        &self,
368        future: AnyFuture<R>,
369        label: Option<TaskLabel>,
370        priority: Priority,
371    ) -> Task<R> {
372        let dispatcher = self.dispatcher.clone();
373        let (runnable, task) = if let Priority::Realtime(realtime) = priority {
374            let location = core::panic::Location::caller();
375            let (mut tx, rx) = flume::bounded::<Runnable<RunnableMeta>>(1);
376
377            dispatcher.spawn_realtime(
378                realtime,
379                Box::new(move || {
380                    while let Ok(runnable) = rx.recv() {
381                        let start = Instant::now();
382                        let location = runnable.metadata().location;
383                        let mut timing = TaskTiming {
384                            location,
385                            start,
386                            end: None,
387                        };
388                        profiler::add_task_timing(timing);
389
390                        runnable.run();
391
392                        let end = Instant::now();
393                        timing.end = Some(end);
394                        profiler::add_task_timing(timing);
395                    }
396                }),
397            );
398
399            async_task::Builder::new()
400                .metadata(RunnableMeta {
401                    location,
402                    app: None,
403                })
404                .spawn(
405                    move |_| future,
406                    move |runnable| {
407                        let _ = tx.send(runnable);
408                    },
409                )
410        } else {
411            let location = core::panic::Location::caller();
412            async_task::Builder::new()
413                .metadata(RunnableMeta {
414                    location,
415                    app: None,
416                })
417                .spawn(
418                    move |_| future,
419                    move |runnable| {
420                        dispatcher.dispatch(RunnableVariant::Meta(runnable), label, priority)
421                    },
422                )
423        };
424
425        runnable.schedule();
426        Task(TaskState::Spawned(task))
427    }
428
429    /// Used by the test harness to run an async test in a synchronous fashion.
430    #[cfg(any(test, feature = "test-support"))]
431    #[track_caller]
432    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
433        if let Ok(value) = self.block_internal(false, future, None) {
434            value
435        } else {
436            unreachable!()
437        }
438    }
439
440    /// Block the current thread until the given future resolves.
441    /// Consider using `block_with_timeout` instead.
442    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
443        if let Ok(value) = self.block_internal(true, future, None) {
444            value
445        } else {
446            unreachable!()
447        }
448    }
449
450    #[cfg(not(any(test, feature = "test-support")))]
451    pub(crate) fn block_internal<Fut: Future>(
452        &self,
453        _background_only: bool,
454        future: Fut,
455        timeout: Option<Duration>,
456    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
457        use std::time::Instant;
458
459        let mut future = Box::pin(future);
460        if timeout == Some(Duration::ZERO) {
461            return Err(future);
462        }
463        let deadline = timeout.map(|timeout| Instant::now() + timeout);
464
465        let parker = parking::Parker::new();
466        let unparker = parker.unparker();
467        let waker = waker_fn(move || {
468            unparker.unpark();
469        });
470        let mut cx = std::task::Context::from_waker(&waker);
471
472        loop {
473            match future.as_mut().poll(&mut cx) {
474                Poll::Ready(result) => return Ok(result),
475                Poll::Pending => {
476                    let timeout =
477                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
478                    if let Some(timeout) = timeout {
479                        if !parker.park_timeout(timeout)
480                            && deadline.is_some_and(|deadline| deadline < Instant::now())
481                        {
482                            return Err(future);
483                        }
484                    } else {
485                        parker.park();
486                    }
487                }
488            }
489        }
490    }
491
492    #[cfg(any(test, feature = "test-support"))]
493    #[track_caller]
494    pub(crate) fn block_internal<Fut: Future>(
495        &self,
496        background_only: bool,
497        future: Fut,
498        timeout: Option<Duration>,
499    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
500        use std::sync::atomic::AtomicBool;
501        use std::time::Instant;
502
503        use parking::Parker;
504
505        let mut future = Box::pin(future);
506        if timeout == Some(Duration::ZERO) {
507            return Err(future);
508        }
509
510        // When using a real platform (e.g., MacPlatform for visual tests that need actual
511        // Metal rendering), there's no test dispatcher. In this case, we block the thread
512        // directly by polling the future and parking until woken. This is required for
513        // VisualTestAppContext which uses real platform rendering but still needs blocking
514        // behavior for code paths like editor initialization that call block_with_timeout.
515        let Some(dispatcher) = self.dispatcher.as_test() else {
516            let deadline = timeout.map(|timeout| Instant::now() + timeout);
517
518            let parker = Parker::new();
519            let unparker = parker.unparker();
520            let waker = waker_fn(move || {
521                unparker.unpark();
522            });
523            let mut cx = std::task::Context::from_waker(&waker);
524
525            loop {
526                match future.as_mut().poll(&mut cx) {
527                    Poll::Ready(result) => return Ok(result),
528                    Poll::Pending => {
529                        let timeout = deadline
530                            .map(|deadline| deadline.saturating_duration_since(Instant::now()));
531                        if let Some(timeout) = timeout {
532                            if !parker.park_timeout(timeout)
533                                && deadline.is_some_and(|deadline| deadline < Instant::now())
534                            {
535                                return Err(future);
536                            }
537                        } else {
538                            parker.park();
539                        }
540                    }
541                }
542            }
543        };
544
545        let mut max_ticks = if timeout.is_some() {
546            dispatcher.gen_block_on_ticks()
547        } else {
548            usize::MAX
549        };
550
551        let parker = Parker::new();
552        let unparker = parker.unparker();
553
554        let awoken = Arc::new(AtomicBool::new(false));
555        let waker = waker_fn({
556            let awoken = awoken.clone();
557            let unparker = unparker.clone();
558            move || {
559                awoken.store(true, Ordering::SeqCst);
560                unparker.unpark();
561            }
562        });
563        let mut cx = std::task::Context::from_waker(&waker);
564
565        let duration = Duration::from_secs(
566            option_env!("GPUI_TEST_TIMEOUT")
567                .and_then(|s| s.parse::<u64>().ok())
568                .unwrap_or(180),
569        );
570        let mut test_should_end_by = Instant::now() + duration;
571
572        loop {
573            match future.as_mut().poll(&mut cx) {
574                Poll::Ready(result) => return Ok(result),
575                Poll::Pending => {
576                    if max_ticks == 0 {
577                        return Err(future);
578                    }
579                    max_ticks -= 1;
580
581                    if !dispatcher.tick(background_only) {
582                        if awoken.swap(false, Ordering::SeqCst) {
583                            continue;
584                        }
585
586                        if !dispatcher.parking_allowed() {
587                            if dispatcher.advance_clock_to_next_delayed() {
588                                continue;
589                            }
590                            let mut backtrace_message = String::new();
591                            let mut waiting_message = String::new();
592                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
593                                backtrace_message =
594                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
595                            }
596                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
597                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
598                            }
599                            panic!(
600                                "parked with nothing left to run{waiting_message}{backtrace_message}",
601                            )
602                        }
603                        dispatcher.push_unparker(unparker.clone());
604                        parker.park_timeout(Duration::from_millis(1));
605                        if Instant::now() > test_should_end_by {
606                            panic!("test timed out after {duration:?} with allow_parking")
607                        }
608                    }
609                }
610            }
611        }
612    }
613
614    /// Block the current thread until the given future resolves
615    /// or `duration` has elapsed.
616    pub fn block_with_timeout<Fut: Future>(
617        &self,
618        duration: Duration,
619        future: Fut,
620    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
621        self.block_internal(true, future, Some(duration))
622    }
623
624    /// Scoped lets you start a number of tasks and waits
625    /// for all of them to complete before returning.
626    pub async fn scoped<'scope, F>(&self, scheduler: F)
627    where
628        F: FnOnce(&mut Scope<'scope>),
629    {
630        let mut scope = Scope::new(self.clone(), Priority::default());
631        (scheduler)(&mut scope);
632        let spawned = mem::take(&mut scope.futures)
633            .into_iter()
634            .map(|f| self.spawn_with_priority(scope.priority, f))
635            .collect::<Vec<_>>();
636        for task in spawned {
637            task.await;
638        }
639    }
640
641    /// Scoped lets you start a number of tasks and waits
642    /// for all of them to complete before returning.
643    pub async fn scoped_priority<'scope, F>(&self, priority: Priority, scheduler: F)
644    where
645        F: FnOnce(&mut Scope<'scope>),
646    {
647        let mut scope = Scope::new(self.clone(), priority);
648        (scheduler)(&mut scope);
649        let spawned = mem::take(&mut scope.futures)
650            .into_iter()
651            .map(|f| self.spawn_with_priority(scope.priority, f))
652            .collect::<Vec<_>>();
653        for task in spawned {
654            task.await;
655        }
656    }
657
658    /// Get the current time.
659    ///
660    /// Calling this instead of `std::time::Instant::now` allows the use
661    /// of fake timers in tests.
662    pub fn now(&self) -> Instant {
663        self.dispatcher.now()
664    }
665
666    /// Returns a task that will complete after the given duration.
667    /// Depending on other concurrent tasks the elapsed duration may be longer
668    /// than requested.
669    pub fn timer(&self, duration: Duration) -> Task<()> {
670        if duration.is_zero() {
671            return Task::ready(());
672        }
673        let location = core::panic::Location::caller();
674        let (runnable, task) = async_task::Builder::new()
675            .metadata(RunnableMeta {
676                location,
677                app: None,
678            })
679            .spawn(move |_| async move {}, {
680                let dispatcher = self.dispatcher.clone();
681                move |runnable| dispatcher.dispatch_after(duration, RunnableVariant::Meta(runnable))
682            });
683        runnable.schedule();
684        Task(TaskState::Spawned(task))
685    }
686
687    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
688    #[cfg(any(test, feature = "test-support"))]
689    pub fn start_waiting(&self) {
690        self.dispatcher.as_test().unwrap().start_waiting();
691    }
692
693    /// in tests, removes the debugging data added by start_waiting
694    #[cfg(any(test, feature = "test-support"))]
695    pub fn finish_waiting(&self) {
696        self.dispatcher.as_test().unwrap().finish_waiting();
697    }
698
699    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
700    #[cfg(any(test, feature = "test-support"))]
701    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> + use<> {
702        self.dispatcher.as_test().unwrap().simulate_random_delay()
703    }
704
705    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
706    #[cfg(any(test, feature = "test-support"))]
707    pub fn deprioritize(&self, task_label: TaskLabel) {
708        self.dispatcher.as_test().unwrap().deprioritize(task_label)
709    }
710
711    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
712    #[cfg(any(test, feature = "test-support"))]
713    pub fn advance_clock(&self, duration: Duration) {
714        self.dispatcher.as_test().unwrap().advance_clock(duration)
715    }
716
717    /// in tests, run one task.
718    #[cfg(any(test, feature = "test-support"))]
719    pub fn tick(&self) -> bool {
720        self.dispatcher.as_test().unwrap().tick(false)
721    }
722
723    /// in tests, run all tasks that are ready to run. If after doing so
724    /// the test still has outstanding tasks, this will panic. (See also [`Self::allow_parking`])
725    #[cfg(any(test, feature = "test-support"))]
726    pub fn run_until_parked(&self) {
727        self.dispatcher.as_test().unwrap().run_until_parked()
728    }
729
730    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
731    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
732    /// do take real async time to run.
733    #[cfg(any(test, feature = "test-support"))]
734    pub fn allow_parking(&self) {
735        self.dispatcher.as_test().unwrap().allow_parking();
736    }
737
738    /// undoes the effect of [`Self::allow_parking`].
739    #[cfg(any(test, feature = "test-support"))]
740    pub fn forbid_parking(&self) {
741        self.dispatcher.as_test().unwrap().forbid_parking();
742    }
743
744    /// adds detail to the "parked with nothing let to run" message.
745    #[cfg(any(test, feature = "test-support"))]
746    pub fn set_waiting_hint(&self, msg: Option<String>) {
747        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
748    }
749
750    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
751    #[cfg(any(test, feature = "test-support"))]
752    pub fn rng(&self) -> StdRng {
753        self.dispatcher.as_test().unwrap().rng()
754    }
755
756    /// How many CPUs are available to the dispatcher.
757    pub fn num_cpus(&self) -> usize {
758        #[cfg(any(test, feature = "test-support"))]
759        return 4;
760
761        #[cfg(not(any(test, feature = "test-support")))]
762        return num_cpus::get();
763    }
764
765    /// Whether we're on the main thread.
766    pub fn is_main_thread(&self) -> bool {
767        self.dispatcher.is_main_thread()
768    }
769
770    #[cfg(any(test, feature = "test-support"))]
771    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
772    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
773        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
774    }
775}
776
777/// ForegroundExecutor runs things on the main thread.
778impl ForegroundExecutor {
779    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
780    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>, liveness: std::sync::Weak<()>) -> Self {
781        Self {
782            dispatcher,
783            liveness,
784            not_send: PhantomData,
785        }
786    }
787
788    /// Enqueues the given Task to run on the main thread at some point in the future.
789    #[track_caller]
790    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
791    where
792        R: 'static,
793    {
794        self.inner_spawn(self.liveness.clone(), Priority::default(), future)
795    }
796
797    /// Enqueues the given Task to run on the main thread at some point in the future.
798    #[track_caller]
799    pub fn spawn_with_priority<R>(
800        &self,
801        priority: Priority,
802        future: impl Future<Output = R> + 'static,
803    ) -> Task<R>
804    where
805        R: 'static,
806    {
807        self.inner_spawn(self.liveness.clone(), priority, future)
808    }
809
810    #[track_caller]
811    pub(crate) fn inner_spawn<R>(
812        &self,
813        app: std::sync::Weak<()>,
814        priority: Priority,
815        future: impl Future<Output = R> + 'static,
816    ) -> Task<R>
817    where
818        R: 'static,
819    {
820        let dispatcher = self.dispatcher.clone();
821        let location = core::panic::Location::caller();
822
823        #[track_caller]
824        fn inner<R: 'static>(
825            dispatcher: Arc<dyn PlatformDispatcher>,
826            future: AnyLocalFuture<R>,
827            location: &'static core::panic::Location<'static>,
828            app: std::sync::Weak<()>,
829            priority: Priority,
830        ) -> Task<R> {
831            let (runnable, task) = spawn_local_with_source_location(
832                future,
833                move |runnable| {
834                    dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable), priority)
835                },
836                RunnableMeta {
837                    location,
838                    app: Some(app),
839                },
840            );
841            runnable.schedule();
842            Task(TaskState::Spawned(task))
843        }
844        inner::<R>(dispatcher, Box::pin(future), location, app, priority)
845    }
846}
847
/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// The future is wrapped so that every poll and every drop asserts it is happening on the
/// thread that called this function. A violation panics with a message naming the spawn site.
/// The returned `Runnable` is not scheduled; the caller is responsible for scheduling it.
///
/// Copy-modified from:
/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
#[track_caller]
fn spawn_local_with_source_location<Fut, S, M>(
    future: Fut,
    schedule: S,
    metadata: M,
) -> (Runnable<M>, async_task::Task<Fut::Output, M>)
where
    Fut: Future + 'static,
    Fut::Output: 'static,
    S: async_task::Schedule<M> + Send + Sync + 'static,
    M: 'static,
{
    /// Returns the current thread's id, cached in a thread-local.
    ///
    /// Falls back to `thread::current()` when the thread-local can no longer be
    /// accessed (`try_with` fails, e.g. while thread-locals are being torn down).
    #[inline]
    fn thread_id() -> ThreadId {
        std::thread_local! {
            static ID: ThreadId = thread::current().id();
        }
        ID.try_with(|id| *id)
            .unwrap_or_else(|_| thread::current().id())
    }

    /// Wrapper that records the spawning thread and spawn location.
    ///
    /// It asserts on every poll and on drop that it is still on the spawning thread.
    struct Checked<F> {
        id: ThreadId,
        // `ManuallyDrop` so `Drop` can run the thread check *before* deciding whether
        // to drop the inner future.
        inner: ManuallyDrop<F>,
        location: &'static Location<'static>,
    }

    impl<F> Drop for Checked<F> {
        fn drop(&mut self) {
            // On the wrong thread this panics before `inner` is touched, so the
            // inner future is leaked rather than dropped on a foreign thread.
            assert!(
                self.id == thread_id(),
                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `inner` is dropped exactly once, here, in place; `self` is being
            // destroyed, so `inner` is never accessed again afterwards.
            unsafe { ManuallyDrop::drop(&mut self.inner) };
        }
    }

    impl<F: Future> Future for Checked<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            assert!(
                self.id == thread_id(),
                "local task polled by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: structural pin projection onto `inner`. `inner` is never moved out
            // of a `Checked`: the only place it is released is the in-place
            // `ManuallyDrop::drop` in `Drop`, and `Checked` is `Unpin` only when `F` is.
            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
        }
    }

    // Wrap the future into one that checks which thread it's on.
    // Because this function is `#[track_caller]`, `Location::caller()` reports the
    // (tracked) call site of the spawn, not this line.
    let future = Checked {
        id: thread_id(),
        inner: ManuallyDrop::new(future),
        location: Location::caller(),
    };

    // SAFETY: async-task's `spawn_unchecked` contract requires the following.
    // - A non-`Send` future's `Runnable` must be used and dropped on the spawning thread.
    //   `Checked` turns any poll or drop on another thread into a panic before the inner
    //   future is touched.
    // - Borrowed data must outlive the `Runnable`/`Waker`. Here `Fut`, `M` and `S` are all
    //   `'static`, and `S` is `Send + Sync` (see the `where` clause above), so the remaining
    //   preconditions hold.
    unsafe {
        async_task::Builder::new()
            .metadata(metadata)
            .spawn_unchecked(move |_| future, schedule)
    }
}
916
/// Scope manages a set of tasks that are enqueued and waited on together.
///
/// Futures may borrow data for `'a`. Dropping the scope blocks until every future it
/// handed out a completion sender to has finished or been dropped.
/// See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
    /// Executor used for `num_cpus` and blocked on while the scope is dropped.
    executor: BackgroundExecutor,
    /// Priority the scope was created with. Presumably applied when its futures are
    /// spawned by `BackgroundExecutor::scoped` (not visible here; confirm).
    priority: Priority,
    /// Futures queued by `Scope::spawn`. Their `'a` borrows have been erased to `'static`;
    /// soundness relies on `Drop for Scope` not returning while any of them is alive.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    /// The scope's own completion sender. Each queued future holds a clone and drops it
    /// when it finishes. It is `None` only once the scope is being dropped.
    tx: Option<mpsc::Sender<()>>,
    /// Completion receiver. `Scope` never sends on it, so `next()` yields `None` once
    /// every sender has been dropped.
    rx: mpsc::Receiver<()>,
    /// Ties the scope to `'a`, the lifetime that spawned futures may borrow for.
    lifetime: PhantomData<&'a ()>,
}
926
927impl<'a> Scope<'a> {
928    fn new(executor: BackgroundExecutor, priority: Priority) -> Self {
929        let (tx, rx) = mpsc::channel(1);
930        Self {
931            executor,
932            priority,
933            tx: Some(tx),
934            rx,
935            futures: Default::default(),
936            lifetime: PhantomData,
937        }
938    }
939
940    /// How many CPUs are available to the dispatcher.
941    pub fn num_cpus(&self) -> usize {
942        self.executor.num_cpus()
943    }
944
945    /// Spawn a future into this scope.
946    #[track_caller]
947    pub fn spawn<F>(&mut self, f: F)
948    where
949        F: Future<Output = ()> + Send + 'a,
950    {
951        let tx = self.tx.clone().unwrap();
952
953        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
954        // dropping this `Scope` blocks until all of the futures have resolved.
955        let f = unsafe {
956            mem::transmute::<
957                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
958                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
959            >(Box::pin(async move {
960                f.await;
961                drop(tx);
962            }))
963        };
964        self.futures.push(f);
965    }
966}
967
968impl Drop for Scope<'_> {
969    fn drop(&mut self) {
970        self.tx.take().unwrap();
971
972        // Wait until the channel is closed, which means that all of the spawned
973        // futures have resolved.
974        self.executor.block(self.rx.next());
975    }
976}
977
#[cfg(test)]
mod test {
    //! Tests for foreground-task cancellation driven by the app's liveness handle.

    use super::*;
    use crate::{App, TestDispatcher, TestPlatform};
    use rand::SeedableRng;
    use std::cell::RefCell;

    /// Helper to create test infrastructure.
    /// Returns (dispatcher, background_executor, app) where app's foreground_executor has liveness.
    ///
    /// The foreground executor only receives a `Weak` to `liveness`; the strong `Arc` is
    /// moved into `App::new_app`. Presumably the app owns it, so dropping the app invalidates
    /// the executor's weak handle (confirm against `App::new_app`).
    fn create_test_app() -> (TestDispatcher, BackgroundExecutor, Rc<crate::AppCell>) {
        // Fixed seed so the test dispatcher's scheduling order is reproducible.
        let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0));
        let arc_dispatcher = Arc::new(dispatcher.clone());
        // Create liveness for task cancellation
        let liveness = std::sync::Arc::new(());
        let liveness_weak = std::sync::Arc::downgrade(&liveness);
        let background_executor = BackgroundExecutor::new(arc_dispatcher.clone());
        let foreground_executor = ForegroundExecutor::new(arc_dispatcher, liveness_weak);

        let platform = TestPlatform::new(background_executor.clone(), foreground_executor);
        let asset_source = Arc::new(());
        let http_client = http_client::FakeHttpClient::with_404_response();

        let app = App::new_app(platform, liveness, asset_source, http_client);
        (dispatcher, background_executor, app)
    }

    /// Baseline: a foreground task runs normally while the app is alive.
    #[test]
    fn sanity_test_tasks_run() {
        let (dispatcher, _background_executor, app) = create_test_app();
        let foreground_executor = app.borrow().foreground_executor.clone();

        let task_ran = Rc::new(RefCell::new(false));

        foreground_executor
            .spawn({
                let task_ran = Rc::clone(&task_ran);
                async move {
                    *task_ran.borrow_mut() = true;
                }
            })
            .detach();

        // Run dispatcher while app is still alive
        dispatcher.run_until_parked();

        // Task should have run
        assert!(
            *task_ran.borrow(),
            "Task should run normally when app is alive"
        );
    }

    /// A task that was scheduled but not yet run must not execute once the app is gone.
    #[test]
    fn test_task_cancelled_when_app_dropped() {
        let (dispatcher, _background_executor, app) = create_test_app();
        let foreground_executor = app.borrow().foreground_executor.clone();
        let app_weak = Rc::downgrade(&app);

        let task_ran = Rc::new(RefCell::new(false));
        let task_ran_clone = Rc::clone(&task_ran);

        // `spawn` schedules the task immediately, but the test dispatcher does not run it
        // until `run_until_parked`, which happens only after the app is dropped.
        foreground_executor
            .spawn(async move {
                *task_ran_clone.borrow_mut() = true;
            })
            .detach();

        drop(app);

        assert!(app_weak.upgrade().is_none(), "App should have been dropped");

        dispatcher.run_until_parked();

        // The task should have been cancelled, not run
        assert!(
            !*task_ran.borrow(),
            "Task should have been cancelled when app was dropped, but it ran!"
        );
    }

    /// An outer task suspended on an inner task: after the app is dropped, neither resumes.
    #[test]
    fn test_nested_tasks_both_cancel() {
        let (dispatcher, _background_executor, app) = create_test_app();
        let foreground_executor = app.borrow().foreground_executor.clone();
        let app_weak = Rc::downgrade(&app);

        let outer_completed = Rc::new(RefCell::new(false));
        let inner_completed = Rc::new(RefCell::new(false));
        let reached_await = Rc::new(RefCell::new(false));

        let outer_flag = Rc::clone(&outer_completed);
        let inner_flag = Rc::clone(&inner_completed);
        let await_flag = Rc::clone(&reached_await);

        // Channel to block the inner task until we're ready
        let (tx, rx) = futures::channel::oneshot::channel::<()>();

        let inner_executor = foreground_executor.clone();

        foreground_executor
            .spawn(async move {
                let inner_task = inner_executor.spawn({
                    let inner_flag = Rc::clone(&inner_flag);
                    async move {
                        // Resolves (with an ignored error) once `tx` is dropped.
                        rx.await.ok();
                        *inner_flag.borrow_mut() = true;
                    }
                });

                *await_flag.borrow_mut() = true;

                inner_task.await;

                *outer_flag.borrow_mut() = true;
            })
            .detach();

        // Run dispatcher until outer task reaches the await point
        // The inner task will be blocked on the channel
        dispatcher.run_until_parked();

        // Verify we actually reached the await point before dropping the app
        assert!(
            *reached_await.borrow(),
            "Outer task should have reached the await point"
        );

        // Neither task should have completed yet
        assert!(
            !*outer_completed.borrow(),
            "Outer task should not have completed yet"
        );
        assert!(
            !*inner_completed.borrow(),
            "Inner task should not have completed yet"
        );

        // Drop the channel sender and app while outer is awaiting inner.
        // Dropping `tx` wakes the inner task, making it runnable again. So if cancellation
        // did not work, the next `run_until_parked` would let it set `inner_completed`.
        drop(tx);
        drop(app);
        assert!(app_weak.upgrade().is_none(), "App should have been dropped");

        // Run dispatcher - both tasks should be cancelled
        dispatcher.run_until_parked();

        // Neither task should have completed (both were cancelled)
        assert!(
            !*outer_completed.borrow(),
            "Outer task should have been cancelled, not completed"
        );
        assert!(
            !*inner_completed.borrow(),
            "Inner task should have been cancelled, not completed"
        );
    }

    /// Awaiting the handle of a task cancelled by app drop is expected to panic.
    #[test]
    // NOTE(review): no `expected = "..."`, so any panic satisfies this test. Consider
    // pinning the cancellation panic message once it is confirmed.
    #[should_panic]
    fn test_polling_cancelled_task_panics() {
        let (dispatcher, background_executor, app) = create_test_app();
        let foreground_executor = app.borrow().foreground_executor.clone();
        let app_weak = Rc::downgrade(&app);

        let task = foreground_executor.spawn(async move { 42 });

        drop(app);

        assert!(app_weak.upgrade().is_none(), "App should have been dropped");

        dispatcher.run_until_parked();

        background_executor.block(task);
    }

    /// The `.fallible()` handle of a cancelled task resolves to `None` instead of panicking.
    #[test]
    fn test_polling_cancelled_task_returns_none_with_fallible() {
        let (dispatcher, background_executor, app) = create_test_app();
        let foreground_executor = app.borrow().foreground_executor.clone();
        let app_weak = Rc::downgrade(&app);

        let task = foreground_executor.spawn(async move { 42 }).fallible();

        drop(app);

        assert!(app_weak.upgrade().is_none(), "App should have been dropped");

        dispatcher.run_until_parked();

        let result = background_executor.block(task);
        assert_eq!(result, None, "Cancelled task should return None");
    }
}
1167        assert_eq!(result, None, "Cancelled task should return None");
1168    }
1169}