poll_promise/
promise.rs

1use std::cell::UnsafeCell;
2
/// The producing half of a [`Promise`]: call [`Self::send`] to fulfil it.
///
/// A `Sender` is meant to deliver exactly one value, and you must eventually do so.
///
/// Dropping a `Sender` without sending anything makes the connected
/// [`Promise`] panic when it is polled.
#[must_use = "You should call Sender::send with the result"]
pub struct Sender<T>(std::sync::mpsc::Sender<T>);

impl<T> Sender<T> {
    /// Hand `value` to the connected [`Promise`], consuming this sender.
    ///
    /// If the [`Promise`] no longer exists, the value is simply discarded.
    pub fn send(self, value: T) {
        let Self(channel) = self;
        // The only possible error is "receiver gone", i.e. nobody is waiting any more.
        let _ = channel.send(value);
    }
}
20
/// How the task that fulfils a [`Promise`] is being run.
///
/// Returned by [`Promise::task_type`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
// `Local` and `Async` are only constructed by spawn functions that exist
// behind executor features, so they are unused in feature-less builds.
#[allow(dead_code)]
pub enum TaskType {
    /// This task is running in the local thread.
    Local,
    /// This task is running async in another thread.
    Async,
    /// This task is running in a different manner.
    None,
}
32
33// ----------------------------------------------------------------------------
34
35/// A promise that waits for the reception of a single value,
36/// presumably from some async task.
37///
38/// A `Promise` starts out waiting for a value.
39/// Each time you call a member method it will check if that value is ready.
40/// Once ready, the `Promise` will store the value until you drop the `Promise`.
41///
42/// Example:
43///
44/// ```
45/// # fn something_slow() {}
46/// # use poll_promise::Promise;
47/// #
48/// let promise = Promise::spawn_thread("slow_operation", move || something_slow());
49///
50/// // Then in the game loop or immediate mode GUI code:
51/// if let Some(result) = promise.ready() {
52///     // Use/show result
53/// } else {
54///     // Show a loading screen
55/// }
56/// ```
57///
58/// If you enable the `tokio` feature you can use `poll-promise` with the [tokio](https://github.com/tokio-rs/tokio)
59/// runtime to run `async` tasks using [`Promise::spawn_async`], [`Promise::spawn_local`], and [`Promise::spawn_blocking`].
60#[must_use]
61pub struct Promise<T: Send + 'static> {
62    data: PromiseImpl<T>,
63    task_type: TaskType,
64
65    #[cfg(feature = "tokio")]
66    join_handle: Option<tokio::task::JoinHandle<()>>,
67
68    #[cfg(feature = "smol")]
69    smol_task: Option<smol::Task<()>>,
70
71    #[cfg(feature = "async-std")]
72    async_std_join_handle: Option<async_std::task::JoinHandle<()>>,
73}
74
// At most one executor feature may be enabled. Each `all(...)` below pairs one
// feature with every feature listed after it, which covers all six conflicting pairs.
// The check is skipped under `cfg(docsrs)`, presumably so documentation builds can
// enable several executors at once.
#[cfg(all(
    not(docsrs),
    any(
        all(
            feature = "tokio",
            any(feature = "smol", feature = "async-std", feature = "web")
        ),
        all(feature = "smol", any(feature = "async-std", feature = "web")),
        all(feature = "async-std", feature = "web"),
    )
))]
compile_error!(
    "You can only specify one of the executor features: 'tokio', 'smol', 'async-std' or 'web'"
);
89
90// Ensure that Promise is !Sync, confirming the safety of the unsafe code.
91static_assertions::assert_not_impl_all!(Promise<u32>: Sync);
92static_assertions::assert_impl_all!(Promise<u32>: Send);
93
94impl<T: Send + 'static> Promise<T> {
95    /// Create a [`Promise`] and a corresponding [`Sender`].
96    ///
97    /// Put the promised value into the sender when it is ready.
98    /// If you drop the `Sender` without putting a value into it,
99    /// it will cause a panic when polling the `Promise`.
100    ///
101    /// See also [`Self::spawn_blocking`], [`Self::spawn_async`], [`Self::spawn_local`], and [`Self::spawn_thread`].
102    pub fn new() -> (Sender<T>, Self) {
103        // We need a channel that we can wait blocking on (for `Self::block_until_ready`).
104        // (`tokio::sync::oneshot` does not support blocking receive).
105        let (tx, rx) = std::sync::mpsc::channel();
106        (
107            Sender(tx),
108            Self {
109                data: PromiseImpl(UnsafeCell::new(PromiseStatus::Pending(rx))),
110                task_type: TaskType::None,
111
112                #[cfg(feature = "tokio")]
113                join_handle: None,
114
115                #[cfg(feature = "async-std")]
116                async_std_join_handle: None,
117
118                #[cfg(feature = "smol")]
119                smol_task: None,
120            },
121        )
122    }
123
124    /// Create a promise that already has the result.
125    pub fn from_ready(value: T) -> Self {
126        Self {
127            data: PromiseImpl(UnsafeCell::new(PromiseStatus::Ready(value))),
128            task_type: TaskType::None,
129
130            #[cfg(feature = "tokio")]
131            join_handle: None,
132
133            #[cfg(feature = "async-std")]
134            async_std_join_handle: None,
135
136            #[cfg(feature = "smol")]
137            smol_task: None,
138        }
139    }
140
141    /// Spawn a future. Runs the task concurrently.
142    ///
143    /// See [`Self::spawn_local`].
144    ///
145    /// You need to compile `poll-promise` with the "tokio" feature for this to be available.
146    ///
147    /// ## tokio
148    /// This should be used for spawning asynchronous work that does _not_ do any heavy CPU computations
149    /// as that will block other spawned tasks and will delay them. For example network IO, timers, etc.
150    ///
151    /// These type of future can have manually blocking code within it though, but has to then manually use
152    /// [`tokio::task::block_in_place`](https://docs.rs/tokio/1.15.0/tokio/task/fn.block_in_place.html) on that,
153    /// or `.await` that future.
154    ///
155    /// If you have a function or closure that you just want to offload to processed in the background, use the [`Self::spawn_blocking`] function instead.
156    ///
157    /// See the [tokio docs](https://docs.rs/tokio/1.15.0/tokio/index.html#cpu-bound-tasks-and-blocking-code) for more details about
158    /// CPU-bound tasks vs async IO tasks.
159    ///
160    /// This is a convenience method, using [`Self::new`] with [`tokio::task::spawn`].
161    ///
162    /// ## Example
163    /// ``` no_run
164    /// # async fn something_async() {}
165    /// # use poll_promise::Promise;
166    /// let promise = Promise::spawn_async(async move { something_async().await });
167    /// ```
168    #[cfg(any(feature = "tokio", feature = "smol", feature = "async-std"))]
169    pub fn spawn_async(future: impl std::future::Future<Output = T> + 'static + Send) -> Self {
170        let (sender, mut promise) = Self::new();
171        promise.task_type = TaskType::Async;
172
173        #[cfg(feature = "tokio")]
174        {
175            promise.join_handle =
176                Some(tokio::task::spawn(async move { sender.send(future.await) }));
177        }
178
179        #[cfg(feature = "smol")]
180        {
181            promise.smol_task =
182                Some(crate::EXECUTOR.spawn(async move { sender.send(future.await) }));
183        }
184
185        #[cfg(feature = "async-std")]
186        {
187            promise.async_std_join_handle =
188                Some(async_std::task::spawn(
189                    async move { sender.send(future.await) },
190                ));
191        }
192
193        promise
194    }
195
196    /// Spawn a future. Runs it in the local thread.
197    ///
198    /// You need to compile `poll-promise` with either the "tokio", "smol", or "web" feature for this to be available.
199    ///
200    /// This is a convenience method, using [`Self::new`] with [`tokio::task::spawn_local`].
201    /// Unlike [`Self::spawn_async`] this method does not require [`Send`].
202    /// However, you will have to set up [`tokio::task::LocalSet`]s yourself.
203    ///
204    /// ## Example
205    /// ``` no_run
206    /// # async fn something_async() {}
207    /// # use poll_promise::Promise;
208    /// let promise = Promise::spawn_local(async move { something_async().await });
209    /// ```
210    #[cfg(any(feature = "tokio", feature = "web", feature = "smol"))]
211    pub fn spawn_local(future: impl std::future::Future<Output = T> + 'static) -> Self {
212        // When using the web feature we don't mutate promise.
213        #[allow(unused_mut)]
214        let (sender, mut promise) = Self::new();
215        promise.task_type = TaskType::Local;
216
217        // This *generally* works but not super well.
218        // Tokio doesn't do any fancy local scheduling.
219        #[cfg(feature = "tokio")]
220        {
221            promise.join_handle = Some(tokio::task::spawn_local(async move {
222                sender.send(future.await);
223            }));
224        }
225
226        #[cfg(feature = "web")]
227        {
228            wasm_bindgen_futures::spawn_local(async move { sender.send(future.await) });
229        }
230
231        #[cfg(feature = "smol")]
232        {
233            promise.smol_task = Some(
234                crate::LOCAL_EXECUTOR
235                    .with(|exec| exec.spawn(async move { sender.send(future.await) })),
236            );
237        }
238
239        promise
240    }
241
242    /// Spawn a blocking closure in a background task.
243    ///
244    /// You need to compile `poll-promise` with the "tokio" feature for this to be available.
245    ///
246    /// ## tokio
247    /// This is a simple mechanism to offload a heavy function/closure to be processed in the thread pool for blocking CPU work.
248    ///
249    /// It can't do any async code. For that, use [`Self::spawn_async`].
250    ///
251    /// This is a convenience method, using [`Self::new`] with [`tokio::task::spawn`] and [`tokio::task::block_in_place`].
252    ///
253    /// ``` no_run
254    /// # fn something_cpu_intensive() {}
255    /// # use poll_promise::Promise;
256    /// let promise = Promise::spawn_blocking(move || something_cpu_intensive());
257    /// ```
258    #[cfg(any(feature = "tokio", feature = "async-std"))]
259    pub fn spawn_blocking<F>(f: F) -> Self
260    where
261        F: FnOnce() -> T + Send + 'static,
262    {
263        let (sender, mut promise) = Self::new();
264        #[cfg(feature = "tokio")]
265        {
266            promise.join_handle = Some(tokio::task::spawn(async move {
267                sender.send(tokio::task::block_in_place(f));
268            }));
269        }
270
271        #[cfg(feature = "async-std")]
272        {
273            promise.async_std_join_handle = Some(async_std::task::spawn_blocking(move || {
274                sender.send(f());
275            }));
276        }
277
278        promise
279    }
280
281    /// Spawn a blocking closure in a background thread.
282    ///
283    /// The first argument is the name of the thread you spawn, passed to [`std::thread::Builder::name`].
284    /// It shows up in panic messages.
285    ///
286    /// This is a convenience method, using [`Self::new`] and [`std::thread::Builder`].
287    ///
288    /// If you are compiling with the "tokio" or "web" features, you should use [`Self::spawn_blocking`] or [`Self::spawn_async`] instead.
289    ///
290    /// ```
291    /// # fn something_slow() {}
292    /// # use poll_promise::Promise;
293    /// let promise = Promise::spawn_thread("slow_operation", move || something_slow());
294    /// ```
295    #[cfg(not(target_arch = "wasm32"))] // can't spawn threads in wasm.
296    pub fn spawn_thread<F>(thread_name: impl Into<String>, f: F) -> Self
297    where
298        F: FnOnce() -> T + Send + 'static,
299    {
300        let (sender, promise) = Self::new();
301        std::thread::Builder::new()
302            .name(thread_name.into())
303            .spawn(move || sender.send(f()))
304            .expect("Failed to spawn thread");
305        promise
306    }
307
308    /// Polls the promise and either returns a reference to the data, or [`None`] if still pending.
309    ///
310    /// Panics if the connected [`Sender`] was dropped before a value was sent.
311    pub fn ready(&self) -> Option<&T> {
312        match self.poll() {
313            std::task::Poll::Pending => None,
314            std::task::Poll::Ready(value) => Some(value),
315        }
316    }
317
318    /// Polls the promise and either returns a mutable reference to the data, or [`None`] if still pending.
319    ///
320    /// Panics if the connected [`Sender`] was dropped before a value was sent.
321    pub fn ready_mut(&mut self) -> Option<&mut T> {
322        match self.poll_mut() {
323            std::task::Poll::Pending => None,
324            std::task::Poll::Ready(value) => Some(value),
325        }
326    }
327
328    /// Returns either the completed promise object or the promise itself if it is not completed yet.
329    ///
330    /// Panics if the connected [`Sender`] was dropped before a value was sent.
331    pub fn try_take(self) -> Result<T, Self> {
332        self.data.try_take().map_err(|data| Self {
333            data,
334            task_type: self.task_type,
335
336            #[cfg(feature = "tokio")]
337            join_handle: None,
338
339            #[cfg(feature = "async-std")]
340            async_std_join_handle: None,
341
342            #[cfg(feature = "smol")]
343            smol_task: self.smol_task,
344        })
345    }
346
347    /// Block execution until ready, then returns a reference to the value.
348    ///
349    /// Panics if the connected [`Sender`] was dropped before a value was sent.
350    pub fn block_until_ready(&self) -> &T {
351        self.data.block_until_ready(self.task_type)
352    }
353
354    /// Block execution until ready, then returns a mutable reference to the value.
355    ///
356    /// Panics if the connected [`Sender`] was dropped before a value was sent.
357    pub fn block_until_ready_mut(&mut self) -> &mut T {
358        self.data.block_until_ready_mut(self.task_type)
359    }
360
361    /// Block execution until ready, then returns the promised value and consumes the `Promise`.
362    ///
363    /// Panics if the connected [`Sender`] was dropped before a value was sent.
364    pub fn block_and_take(self) -> T {
365        self.data.block_until_ready(self.task_type);
366        match self.data.0.into_inner() {
367            PromiseStatus::Pending(_) => unreachable!(),
368            PromiseStatus::Ready(value) => value,
369        }
370    }
371
372    /// Returns either a reference to the ready value [`std::task::Poll::Ready`]
373    /// or [`std::task::Poll::Pending`].
374    ///
375    /// Panics if the connected [`Sender`] was dropped before a value was sent.
376    pub fn poll(&self) -> std::task::Poll<&T> {
377        self.data.poll(self.task_type)
378    }
379
380    /// Returns either a mut reference to the ready value in a [`std::task::Poll::Ready`]
381    /// or a [`std::task::Poll::Pending`].
382    ///
383    /// Panics if the connected [`Sender`] was dropped before a value was sent.
384    pub fn poll_mut(&mut self) -> std::task::Poll<&mut T> {
385        self.data.poll_mut(self.task_type)
386    }
387
388    /// Returns the type of task this promise is running.
389    /// See [`TaskType`].
390    pub fn task_type(&self) -> TaskType {
391        self.task_type
392    }
393
394    /// Abort the running task spawned by [`Self::spawn_async`].
395    #[cfg(feature = "tokio")]
396    pub fn abort(self) {
397        if let Some(join_handle) = self.join_handle {
398            join_handle.abort();
399        }
400    }
401}
402
403// ----------------------------------------------------------------------------
404
405enum PromiseStatus<T: Send + 'static> {
406    Pending(std::sync::mpsc::Receiver<T>),
407    Ready(T),
408}
409
410struct PromiseImpl<T: Send + 'static>(UnsafeCell<PromiseStatus<T>>);
411
412impl<T: Send + 'static> PromiseImpl<T> {
413    #[allow(unused_variables)]
414    fn poll_mut(&mut self, task_type: TaskType) -> std::task::Poll<&mut T> {
415        let inner = self.0.get_mut();
416        match inner {
417            PromiseStatus::Pending(rx) => {
418                #[cfg(all(feature = "smol", feature = "smol_tick_poll"))]
419                Self::tick(task_type);
420                if let Ok(value) = rx.try_recv() {
421                    *inner = PromiseStatus::Ready(value);
422                    match inner {
423                        PromiseStatus::Ready(ref mut value) => std::task::Poll::Ready(value),
424                        PromiseStatus::Pending(_) => unreachable!(),
425                    }
426                } else {
427                    std::task::Poll::Pending
428                }
429            }
430            PromiseStatus::Ready(ref mut value) => std::task::Poll::Ready(value),
431        }
432    }
433
434    /// Returns either the completed promise object or the promise itself if it is not completed yet.
435    fn try_take(self) -> Result<T, Self> {
436        let inner = self.0.into_inner();
437        match inner {
438            PromiseStatus::Pending(ref rx) => match rx.try_recv() {
439                Ok(value) => Ok(value),
440                Err(std::sync::mpsc::TryRecvError::Empty) => {
441                    Err(PromiseImpl(UnsafeCell::new(inner)))
442                }
443                Err(std::sync::mpsc::TryRecvError::Disconnected) => {
444                    panic!("The Promise Sender was dropped")
445                }
446            },
447            PromiseStatus::Ready(value) => Ok(value),
448        }
449    }
450
451    #[allow(unsafe_code)]
452    #[allow(unused_variables)]
453    fn poll(&self, task_type: TaskType) -> std::task::Poll<&T> {
454        let this = unsafe {
455            // SAFETY: This is safe since Promise (and PromiseData) are !Sync and thus
456            // need external synchronization anyway. We can only transition from
457            // Pending->Ready, not the other way around, so once we're Ready we'll
458            // stay ready.
459            self.0.get().as_mut().expect("UnsafeCell should be valid")
460        };
461        match this {
462            PromiseStatus::Pending(rx) => {
463                #[cfg(all(feature = "smol", feature = "smol_tick_poll"))]
464                Self::tick(task_type);
465                match rx.try_recv() {
466                    Ok(value) => {
467                        *this = PromiseStatus::Ready(value);
468                        match this {
469                            PromiseStatus::Ready(ref value) => std::task::Poll::Ready(value),
470                            PromiseStatus::Pending(_) => unreachable!(),
471                        }
472                    }
473                    Err(std::sync::mpsc::TryRecvError::Empty) => std::task::Poll::Pending,
474                    Err(std::sync::mpsc::TryRecvError::Disconnected) => {
475                        panic!("The Promise Sender was dropped")
476                    }
477                }
478            }
479            PromiseStatus::Ready(ref value) => std::task::Poll::Ready(value),
480        }
481    }
482
483    #[allow(unused_variables)]
484    fn block_until_ready_mut(&mut self, task_type: TaskType) -> &mut T {
485        // Constantly poll until we're ready.
486        #[cfg(feature = "smol")]
487        while self.poll(task_type).is_pending() {
488            // Tick unless poll does it for us.
489            #[cfg(not(feature = "smol_tick_poll"))]
490            Self::tick(task_type);
491        }
492        let inner = self.0.get_mut();
493        match inner {
494            PromiseStatus::Pending(rx) => {
495                let value = rx.recv().expect("The Promise Sender was dropped");
496                *inner = PromiseStatus::Ready(value);
497                match inner {
498                    PromiseStatus::Ready(ref mut value) => value,
499                    PromiseStatus::Pending(_) => unreachable!(),
500                }
501            }
502            PromiseStatus::Ready(ref mut value) => value,
503        }
504    }
505
506    #[allow(unsafe_code)]
507    #[allow(unused_variables)]
508    fn block_until_ready(&self, task_type: TaskType) -> &T {
509        // Constantly poll until we're ready.
510        #[cfg(feature = "smol")]
511        while self.poll(task_type).is_pending() {
512            // Tick unless poll does it for us.
513            #[cfg(not(feature = "smol_tick_poll"))]
514            Self::tick(task_type);
515        }
516        let this = unsafe {
517            // SAFETY: This is safe since Promise (and PromiseData) are !Sync and thus
518            // need external synchronization anyway. We can only transition from
519            // Pending->Ready, not the other way around, so once we're Ready we'll
520            // stay ready.
521            self.0.get().as_mut().expect("UnsafeCell should be valid")
522        };
523        match this {
524            PromiseStatus::Pending(rx) => {
525                let value = rx.recv().expect("The Promise Sender was dropped");
526                *this = PromiseStatus::Ready(value);
527                match this {
528                    PromiseStatus::Ready(ref value) => value,
529                    PromiseStatus::Pending(_) => unreachable!(),
530                }
531            }
532            PromiseStatus::Ready(ref value) => value,
533        }
534    }
535
536    #[cfg(feature = "smol")]
537    fn tick(task_type: TaskType) {
538        match task_type {
539            TaskType::Local => {
540                crate::tick_local();
541            }
542            TaskType::Async => {
543                crate::tick();
544            }
545            TaskType::None => (),
546        };
547    }
548}