Skip to main content

open_gpui_scheduler/
executor.rs

1use crate::{Instant, Priority, RunnableMeta, Scheduler, SessionId, Timer};
2use async_task::Runnable;
3use std::{
4    any::Any,
5    future::Future,
6    marker::PhantomData,
7    mem::ManuallyDrop,
8    panic::Location,
9    pin::Pin,
10    rc::Rc,
11    sync::Arc,
12    task::{Context, Poll},
13    thread::{self, ThreadId},
14    time::Duration,
15};
16
17/// A `!Send` executor pinned to a single session. Tasks spawned on it run in
18/// order on whichever thread drains the dispatch destination supplied at
19/// construction time — typically the main thread for the default session, or
20/// a dedicated OS thread for sessions created by `spawn_dedicated_thread`.
21#[derive(Clone)]
22pub struct LocalExecutor {
23    session_id: SessionId,
24    scheduler: Arc<dyn Scheduler>,
25    // Spawned tasks' schedule callbacks each hold an `Arc` clone of this
26    // closure, so the destination it captures stays alive as long as work
27    // could still land on it.
28    dispatch: Arc<dyn Fn(Runnable<RunnableMeta>) + Send + Sync>,
29    not_send: PhantomData<Rc<()>>,
30}
31
32impl LocalExecutor {
33    /// Constructs a local executor that runs spawned tasks by sending their
34    /// runnables through `dispatch`. The `scheduler` is retained for access to
35    /// clocks, timers, and other scheduler-level services.
36    ///
37    /// For the common case of routing runnables through
38    /// `Scheduler::schedule_local`, callers pass a closure that does exactly
39    /// that. `spawn_dedicated_thread` instead passes a closure that sends to
40    /// the dedicated thread's channel.
41    pub fn new(
42        session_id: SessionId,
43        scheduler: Arc<dyn Scheduler>,
44        dispatch: impl Fn(Runnable<RunnableMeta>) + Send + Sync + 'static,
45    ) -> Self {
46        Self {
47            session_id,
48            scheduler,
49            dispatch: Arc::new(dispatch),
50            not_send: PhantomData,
51        }
52    }
53
54    pub fn session_id(&self) -> SessionId {
55        self.session_id
56    }
57
58    pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
59        &self.scheduler
60    }
61
62    #[track_caller]
63    pub fn spawn<F>(&self, future: F) -> Task<F::Output>
64    where
65        F: Future + 'static,
66        F::Output: 'static,
67    {
68        let dispatch = self.dispatch.clone();
69        let location = Location::caller();
70        let (runnable, task) = spawn_local_with_source_location(
71            future,
72            move |runnable| dispatch(runnable),
73            RunnableMeta {
74                location,
75                spawned: crate::SpawnTime(Instant::now()),
76            },
77        );
78        runnable.schedule();
79        Task(TaskState::Spawned(task))
80    }
81
82    pub fn block_on<Fut: Future>(&self, future: Fut) -> Fut::Output {
83        use std::cell::Cell;
84
85        let output = Cell::new(None);
86        let future = async {
87            output.set(Some(future.await));
88        };
89        let mut future = std::pin::pin!(future);
90
91        self.scheduler
92            .block(Some(self.session_id), future.as_mut(), None);
93
94        output.take().expect("block_on future did not complete")
95    }
96
97    /// Block until the future completes or timeout occurs.
98    /// Returns Ok(output) if completed, Err(future) if timed out.
99    pub fn block_with_timeout<Fut: Future>(
100        &self,
101        timeout: Duration,
102        future: Fut,
103    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
104        use std::cell::Cell;
105
106        let output = Cell::new(None);
107        let mut future = Box::pin(future);
108
109        {
110            let future_ref = &mut future;
111            let wrapper = async {
112                output.set(Some(future_ref.await));
113            };
114            let mut wrapper = std::pin::pin!(wrapper);
115
116            self.scheduler
117                .block(Some(self.session_id), wrapper.as_mut(), Some(timeout));
118        }
119
120        match output.take() {
121            Some(value) => Ok(value),
122            None => Err(future),
123        }
124    }
125
126    #[track_caller]
127    pub fn timer(&self, duration: Duration) -> Timer {
128        self.scheduler.timer(duration)
129    }
130
131    pub fn now(&self) -> Instant {
132        self.scheduler.clock().now()
133    }
134
135    /// Spawn a closure on a fresh session pinned to its own [`LocalExecutor`].
136    /// The closure runs on a new OS thread under `PlatformScheduler`, or on
137    /// the test scheduler's loop under `TestScheduler`.
138    ///
139    /// The returned `Task` represents the dedicated work: dropping it cancels
140    /// the dedicated closure, `.await`ing it yields the closure's return
141    /// value, `.detach()`ing it lets the dedicated work run independently of
142    /// the caller.
143    #[track_caller]
144    pub fn spawn_dedicated<F, Fut>(&self, f: F) -> Task<Fut::Output>
145    where
146        F: FnOnce(LocalExecutor) -> Fut + Send + 'static,
147        Fut: Future + 'static,
148        Fut::Output: Send + Sync + 'static,
149    {
150        self.scheduler
151            .clone()
152            .spawn_dedicated(box_dedicated(f))
153            .downcast::<Fut::Output>()
154    }
155}
156
157/// Boxes the user-supplied dedicated closure into the type-erased shape
158/// expected by [`Scheduler::spawn_dedicated`]. The user's `Fut::Output` is
159/// boxed as `Box<dyn Any + Send + Sync>` on the dedicated side and downcast
160/// back to `Fut::Output` by [`Task::downcast`] in the wrapper.
161fn box_dedicated<F, Fut>(
162    f: F,
163) -> Box<
164    dyn FnOnce(LocalExecutor) -> Pin<Box<dyn Future<Output = Box<dyn Any + Send + Sync>> + 'static>>
165        + Send
166        + 'static,
167>
168where
169    F: FnOnce(LocalExecutor) -> Fut + Send + 'static,
170    Fut: Future + 'static,
171    Fut::Output: Send + Sync + 'static,
172{
173    Box::new(move |executor| {
174        Box::pin(async move { Box::new(f(executor).await) as Box<dyn Any + Send + Sync> })
175    })
176}
177
178#[derive(Clone)]
179pub struct BackgroundExecutor {
180    scheduler: Arc<dyn Scheduler>,
181}
182
183impl BackgroundExecutor {
184    pub fn new(scheduler: Arc<dyn Scheduler>) -> Self {
185        Self { scheduler }
186    }
187
188    #[track_caller]
189    pub fn spawn<F>(&self, future: F) -> Task<F::Output>
190    where
191        F: Future + Send + 'static,
192        F::Output: Send + 'static,
193    {
194        self.spawn_with_priority(Priority::default(), future)
195    }
196
197    #[track_caller]
198    pub fn spawn_with_priority<F>(&self, priority: Priority, future: F) -> Task<F::Output>
199    where
200        F: Future + Send + 'static,
201        F::Output: Send + 'static,
202    {
203        let scheduler = Arc::downgrade(&self.scheduler);
204        let location = Location::caller();
205        let (runnable, task) = async_task::Builder::new()
206            .metadata(RunnableMeta {
207                location,
208                spawned: crate::SpawnTime(Instant::now()),
209            })
210            .spawn(
211                move |_| future,
212                move |runnable| {
213                    if let Some(scheduler) = scheduler.upgrade() {
214                        scheduler.schedule_background_with_priority(runnable, priority);
215                    }
216                },
217            );
218        runnable.schedule();
219        Task(TaskState::Spawned(task))
220    }
221
222    /// Spawns a future on a dedicated realtime thread for audio processing.
223    #[track_caller]
224    pub fn spawn_realtime<F>(&self, future: F) -> Task<F::Output>
225    where
226        F: Future + Send + 'static,
227        F::Output: Send + 'static,
228    {
229        let location = Location::caller();
230        let (tx, rx) = flume::bounded::<async_task::Runnable<RunnableMeta>>(1);
231
232        self.scheduler.spawn_realtime(Box::new(move || {
233            while let Ok(runnable) = rx.recv() {
234                runnable.run();
235            }
236        }));
237
238        let (runnable, task) = async_task::Builder::new()
239            .metadata(RunnableMeta {
240                location,
241                spawned: crate::SpawnTime(Instant::now()),
242            })
243            .spawn(
244                move |_| future,
245                move |runnable| {
246                    let _ = tx.send(runnable);
247                },
248            );
249        runnable.schedule();
250        Task(TaskState::Spawned(task))
251    }
252
253    #[track_caller]
254    pub fn timer(&self, duration: Duration) -> Timer {
255        self.scheduler.timer(duration)
256    }
257
258    pub fn now(&self) -> Instant {
259        self.scheduler.clock().now()
260    }
261
262    pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
263        &self.scheduler
264    }
265
266    /// Spawn a closure on a fresh session pinned to its own [`LocalExecutor`].
267    /// The closure runs on a new OS thread under `PlatformScheduler`, or on
268    /// the test scheduler's loop under `TestScheduler`.
269    ///
270    /// The returned `Task` represents the dedicated work: dropping it cancels
271    /// the dedicated closure, `.await`ing it yields the closure's return
272    /// value, `.detach()`ing it lets the dedicated work run independently of
273    /// the caller.
274    #[track_caller]
275    pub fn spawn_dedicated<F, Fut>(&self, f: F) -> Task<Fut::Output>
276    where
277        F: FnOnce(LocalExecutor) -> Fut + Send + 'static,
278        Fut: Future + 'static,
279        Fut::Output: Send + Sync + 'static,
280    {
281        self.scheduler
282            .clone()
283            .spawn_dedicated(box_dedicated(f))
284            .downcast::<Fut::Output>()
285    }
286}
287
288/// Task is a primitive that allows work to happen in the background.
289///
290/// It implements [`Future`] so you can `.await` on it.
291///
292/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
293/// the task to continue running, but with no way to return a value.
294#[must_use]
295pub struct Task<T>(TaskState<T>);
296
297enum TaskState<T> {
298    /// A task that is ready to return a value
299    Ready(Option<T>),
300
301    /// A task that is currently running.
302    Spawned(async_task::Task<T, RunnableMeta>),
303
304    /// A typed view of a [`Task<Box<dyn Any + Send + Sync>>`] obtained via
305    /// [`Task::downcast`]. The inner task drives the actual work; the
306    /// downcast layer just unwraps the `Box<dyn Any + Send + Sync>` on poll.
307    Downcast {
308        inner: Box<Task<Box<dyn Any + Send + Sync>>>,
309        marker: PhantomData<fn() -> T>,
310    },
311}
312
313impl<T> Task<T> {
314    /// Creates a new task that will resolve with the value
315    pub fn ready(val: T) -> Self {
316        Task(TaskState::Ready(Some(val)))
317    }
318
319    /// Creates a Task from an async_task::Task
320    pub fn from_async_task(task: async_task::Task<T, RunnableMeta>) -> Self {
321        Task(TaskState::Spawned(task))
322    }
323
324    pub fn is_ready(&self) -> bool {
325        match &self.0 {
326            TaskState::Ready(_) => true,
327            TaskState::Spawned(task) => task.is_finished(),
328            TaskState::Downcast { inner, .. } => inner.is_ready(),
329        }
330    }
331
332    /// Detaching a task runs it to completion in the background
333    pub fn detach(self) {
334        match self {
335            Task(TaskState::Ready(_)) => {}
336            Task(TaskState::Spawned(task)) => task.detach(),
337            Task(TaskState::Downcast { inner, .. }) => inner.detach(),
338        }
339    }
340
341    /// Converts this task into a fallible task that returns `Option<T>`.
342    pub fn fallible(self) -> FallibleTask<T> {
343        FallibleTask(match self.0 {
344            TaskState::Ready(val) => FallibleTaskState::Ready(val),
345            TaskState::Spawned(task) => FallibleTaskState::Spawned(task.fallible()),
346            TaskState::Downcast { inner, .. } => FallibleTaskState::Downcast {
347                inner: Box::new(inner.fallible()),
348                marker: PhantomData,
349            },
350        })
351    }
352}
353
354impl Task<Box<dyn Any + Send + Sync>> {
355    /// Reinterprets the boxed output as a concrete `T` via downcast on
356    /// completion. Used by [`LocalExecutor::spawn_dedicated`] and
357    /// [`BackgroundExecutor::spawn_dedicated`] to recover the user closure's
358    /// `Fut::Output` from the dyn-safe [`Scheduler::spawn_dedicated`].
359    ///
360    /// Panics on poll if the inner output is not in fact a `T` -- a logic
361    /// error in whatever produced the inner task, since the downcast type is
362    /// chosen by the caller of `downcast`.
363    pub fn downcast<T: Send + Sync + 'static>(self) -> Task<T> {
364        Task(TaskState::Downcast {
365            inner: Box::new(self),
366            marker: PhantomData,
367        })
368    }
369}
370
371impl<T> std::fmt::Debug for Task<T> {
372    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373        match &self.0 {
374            TaskState::Ready(_) => f.debug_tuple("Task::Ready").finish(),
375            TaskState::Spawned(task) => f.debug_tuple("Task::Spawned").field(task).finish(),
376            TaskState::Downcast { inner, .. } => {
377                f.debug_tuple("Task::Downcast").field(inner).finish()
378            }
379        }
380    }
381}
382
383/// A task that returns `Option<T>` instead of panicking when cancelled.
384#[must_use]
385pub struct FallibleTask<T>(FallibleTaskState<T>);
386
387enum FallibleTaskState<T> {
388    /// A task that is ready to return a value
389    Ready(Option<T>),
390
391    /// A task that is currently running (wraps async_task::FallibleTask).
392    Spawned(async_task::FallibleTask<T, RunnableMeta>),
393
394    /// Mirror of [`TaskState::Downcast`] for fallible tasks.
395    Downcast {
396        inner: Box<FallibleTask<Box<dyn Any + Send + Sync>>>,
397        marker: PhantomData<fn() -> T>,
398    },
399}
400
401impl<T> FallibleTask<T> {
402    /// Creates a new fallible task that will resolve with the value.
403    pub fn ready(val: T) -> Self {
404        FallibleTask(FallibleTaskState::Ready(Some(val)))
405    }
406
407    /// Detaching a task runs it to completion in the background.
408    pub fn detach(self) {
409        match self.0 {
410            FallibleTaskState::Ready(_) => {}
411            FallibleTaskState::Spawned(task) => task.detach(),
412            FallibleTaskState::Downcast { inner, .. } => inner.detach(),
413        }
414    }
415}
416
417impl<T: 'static> Future for FallibleTask<T> {
418    type Output = Option<T>;
419
420    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
421        match unsafe { self.get_unchecked_mut() } {
422            FallibleTask(FallibleTaskState::Ready(val)) => Poll::Ready(val.take()),
423            FallibleTask(FallibleTaskState::Spawned(task)) => Pin::new(task).poll(cx),
424            FallibleTask(FallibleTaskState::Downcast { inner, .. }) => {
425                match Pin::new(inner.as_mut()).poll(cx) {
426                    Poll::Ready(Some(boxed_any)) => Poll::Ready(Some(
427                        *boxed_any
428                            .downcast::<T>()
429                            .expect("FallibleTask::poll: downcast type mismatch"),
430                    )),
431                    Poll::Ready(None) => Poll::Ready(None),
432                    Poll::Pending => Poll::Pending,
433                }
434            }
435        }
436    }
437}
438
439impl<T> std::fmt::Debug for FallibleTask<T> {
440    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
441        match &self.0 {
442            FallibleTaskState::Ready(_) => f.debug_tuple("FallibleTask::Ready").finish(),
443            FallibleTaskState::Spawned(task) => {
444                f.debug_tuple("FallibleTask::Spawned").field(task).finish()
445            }
446            FallibleTaskState::Downcast { inner, .. } => f
447                .debug_tuple("FallibleTask::Downcast")
448                .field(inner)
449                .finish(),
450        }
451    }
452}
453
454impl<T: 'static> Future for Task<T> {
455    type Output = T;
456
457    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
458        match unsafe { self.get_unchecked_mut() } {
459            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
460            Task(TaskState::Spawned(task)) => Pin::new(task).poll(cx),
461            Task(TaskState::Downcast { inner, .. }) => match Pin::new(inner.as_mut()).poll(cx) {
462                Poll::Ready(boxed_any) => Poll::Ready(
463                    *boxed_any
464                        .downcast::<T>()
465                        .expect("Task::poll: downcast type mismatch"),
466                ),
467                Poll::Pending => Poll::Pending,
468            },
469        }
470    }
471}
472
473/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
474#[track_caller]
475fn spawn_local_with_source_location<Fut, S>(
476    future: Fut,
477    schedule: S,
478    metadata: RunnableMeta,
479) -> (
480    async_task::Runnable<RunnableMeta>,
481    async_task::Task<Fut::Output, RunnableMeta>,
482)
483where
484    Fut: Future + 'static,
485    Fut::Output: 'static,
486    S: async_task::Schedule<RunnableMeta> + Send + Sync + 'static,
487{
488    #[inline]
489    fn thread_id() -> ThreadId {
490        std::thread_local! {
491            static ID: ThreadId = thread::current().id();
492        }
493        ID.try_with(|id| *id)
494            .unwrap_or_else(|_| thread::current().id())
495    }
496
497    struct Checked<F> {
498        id: ThreadId,
499        inner: ManuallyDrop<F>,
500        location: &'static Location<'static>,
501    }
502
503    impl<F> Drop for Checked<F> {
504        fn drop(&mut self) {
505            assert_eq!(
506                self.id,
507                thread_id(),
508                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
509                self.location
510            );
511            // SAFETY: `inner` is wrapped in `ManuallyDrop`, so this is the only
512            // place it is dropped. The thread check above ensures local futures
513            // are dropped on the thread that created them.
514            unsafe { ManuallyDrop::drop(&mut self.inner) };
515        }
516    }
517
518    impl<F: Future> Future for Checked<F> {
519        type Output = F::Output;
520
521        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
522            // SAFETY: We don't move any fields out of `self`; this mutable
523            // reference is only used to check metadata and to project the pin to
524            // `inner` below.
525            let this = unsafe { self.get_unchecked_mut() };
526            assert!(
527                this.id == thread_id(),
528                "local task polled by a thread that didn't spawn it. Task spawned at {}",
529                this.location
530            );
531            // SAFETY: `inner` is structurally pinned by `Checked`; after
532            // `Checked` is pinned, `inner` is never moved. The thread check
533            // above ensures the local future is only polled by its spawning
534            // thread.
535            unsafe { Pin::new_unchecked(&mut *this.inner).poll(cx) }
536        }
537    }
538
539    let location = metadata.location;
540
541    let future = move |_| Checked {
542        id: thread_id(),
543        inner: ManuallyDrop::new(future),
544        location,
545    };
546
547    let builder = async_task::Builder::new().metadata(metadata);
548    // SAFETY: `Checked` enforces the invariants required by `spawn_unchecked`:
549    // the non-`Send` future is only polled and dropped on the thread that
550    // spawned it.
551    unsafe { builder.spawn_unchecked(future, schedule) }
552}