async_promise/
lib.rs

1#![warn(missing_docs)]
2#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/", env!("CARGO_PKG_README")))]
3
4use std::cell::UnsafeCell;
5use std::mem::MaybeUninit;
6use std::sync::atomic::{AtomicU8, Ordering};
7use std::sync::{Arc, Mutex, Weak};
8use std::task::{Context, Poll, Waker};
9
10/// Creates a new promise and resolver pair.
11pub fn channel<T>() -> (Resolve<T>, Promise<T>) {
12    let inner = Arc::new(Inner::new());
13    (
14        Resolve {
15            inner: Some(Arc::downgrade(&inner)),
16        },
17        Promise { inner },
18    )
19}
20
21/// A container for a value of type `T` that may not yet be resolved.
22///
23/// Multiple consumers may obtain the value via shared reference (`&`) once resolved.
24pub struct Promise<T> {
25    inner: Arc<Inner<T>>,
26}
27impl<T> Promise<T> {
28    /// Waits for the promise to be resolved, returning the value once set.
29    ///
30    /// Returns `None` if the resolver was dropped before sending a value.
31    pub async fn wait(&self) -> Option<&T> {
32        std::future::poll_fn(|cx| self.inner.poll_get(cx)).await
33    }
34
35    /// Attempts to get the value if it has been resolved, returning an error if not.
36    ///
37    /// Returns [`PromiseError::Empty`] if the value has not yet been set, and [`PromiseError::Dropped`] if the resolver
38    /// was dropped before sending a value.
39    pub fn try_get(&self) -> Result<&T, PromiseError> {
40        self.inner
41            .get()
42            .ok_or(PromiseError::Empty)
43            .and_then(|value_opt| value_opt.ok_or(PromiseError::Dropped))
44    }
45
46    /// Returns if the `Resolver` has been resolved OR dropped.
47    ///
48    /// * `true` indicates the `Promise` is complete, either by resolving or dropping.
49    /// * `false` indicates the value may still resolve in the future.
50    pub fn is_done(&self) -> bool {
51        self.inner.get().is_some()
52    }
53}
54
55/// An error that may occur when trying to get the value from a [`Promise`], see [`Promise::try_get`].
56#[derive(Debug, Eq, PartialEq, Clone, thiserror::Error)]
57pub enum PromiseError {
58    /// The resolver has not yet sent a value.
59    #[error("value not yet sent")]
60    Empty,
61    /// The resolver was dropped before sending a value.
62    #[error("closed before a value was sent")]
63    Dropped,
64}
65
66// SAFETY: must not be clonable, as that would allow simultaneous write access to the `Inner` value.
67/// The resolve/send half of a promise. Created via [`channel`], and used to resolve the corresponding [`Promise`] with
68/// a value.
69pub struct Resolve<T> {
70    inner: Option<Weak<Inner<T>>>,
71}
72impl<T> Resolve<T> {
73    /// Resolves the promise with the given value, consuming this `Resolve` in the process.
74    ///
75    /// This will panic if the promise has already been resolved.
76    pub fn into_resolve(mut self, value: T) {
77        self.resolve(value).unwrap_or_else(|_| panic!("already resolved"));
78    }
79
80    /// Resolves the promise with the given value, returning an error if the promise has already been resolved.
81    pub fn resolve(&mut self, value: T) -> Result<(), T> {
82        let Some(inner) = self.inner.take() else {
83            return Err(value);
84        };
85
86        if let Some(inner) = inner.upgrade() {
87            // SAFETY: `&mut self: Resolve` has exclusive access to `resolve` once.
88            unsafe {
89                inner.resolve(Some(value));
90            }
91        }
92        Ok(())
93    }
94}
95impl<T> Drop for Resolve<T> {
96    fn drop(&mut self) {
97        if let Some(inner) = self.inner.take().and_then(|weak| weak.upgrade()) {
98            // SAFETY: `&mut self: Resolve` has exclusive access to call resolve once. Because we use `inner.take()`, we
99            // know `Self::resolve` was not called.
100            unsafe {
101                inner.resolve(None);
102            }
103        }
104    }
105}
106
/// Lowest bit; the `FLAG_*` masks below are built from it.
const BIT: u8 = 0b1;
/// Flag for when the [`Resolve`] will no longer change the value of the [`Inner`], because it was either resolved or
/// dropped.
const FLAG_COMPLETED: u8 = BIT;
/// Flag for when the value is set and can be read.
const FLAG_VALUE_SET: u8 = BIT << 1;

/// State shared between one [`Resolve`] (weak link) and its [`Promise`] (strong link).
///
/// # Safety
///
/// Any thread with an `&self` may access the `value` field according to the following rules:
///
///  1. Iff NOT `FLAG_COMPLETED`, the `value` field may be initialized by the owning `Resolve` once.
///  2. Iff `FLAG_VALUE_SET`, the `value` field may be accessed immutably by any thread.
///
/// If `FLAG_COMPLETED` but NOT `FLAG_VALUE_SET`, then the owning `Resolve` was dropped before the value was set.
///
/// Because `value` is a `MaybeUninit`, its contents are never dropped automatically; the `Drop` impl below drops it
/// in place when `FLAG_VALUE_SET` is set.
///
/// # Table of possible states
/// |                       | NOT `FLAG_COMPLETED`  | `FLAG_COMPLETED`                   |
/// |-----------------------|-----------------------|------------------------------------|
/// | NOT `FLAG_VALUE_SET`  | Waiting for `Resolve` | `Resolve` dropped, `value` not set |
/// | `FLAG_VALUE_SET`      | INVALID               | `value` set, accessible            |
struct Inner<T> {
    /// Combination of `FLAG_*` bits; starts at `0` and is stored once, by [`Inner::resolve`].
    flag: AtomicU8,
    /// The resolved value; initialized only when `flag` contains `FLAG_VALUE_SET`.
    value: UnsafeCell<MaybeUninit<T>>,
    /// List of wakers waiting for the value to be set.
    wakers: Mutex<Vec<Waker>>,
}

impl<T> Inner<T> {
    /// Creates a not-yet-completed state with no value and no registered wakers.
    const fn new() -> Self {
        Self {
            flag: AtomicU8::new(0),
            value: UnsafeCell::new(MaybeUninit::uninit()),
            wakers: Mutex::new(Vec::new()),
        }
    }

    /// Locks the waker list, recovering it if a previous holder panicked.
    ///
    /// A `Vec<Waker>` has no invariant a panic could break, so poisoning is ignored rather than propagated;
    /// propagating it would make [`Inner::resolve`] (which also runs from `Resolve`'s `Drop`) panic.
    fn lock_wakers(&self) -> std::sync::MutexGuard<'_, Vec<Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Polls for completion, registering `cx`'s waker to be woken by [`Inner::resolve`] if not yet complete.
    ///
    /// A waker that [`Waker::will_wake`] the same task as an already-registered one is not stored again, so a future
    /// that is polled repeatedly before resolution does not grow the list without bound.
    fn poll_get<'a>(&'a self, cx: &mut Context<'_>) -> Poll<Option<&'a T>> {
        if let Some(value_opt) = self.get() {
            return Poll::Ready(value_opt);
        }
        let mut wakers = self.lock_wakers();
        // Re-check under the lock: `resolve` stores the flag *before* locking to drain `wakers`, so either the flag is
        // visible here or our waker is registered before `resolve` drains the list.
        if let Some(value_opt) = self.get() {
            return Poll::Ready(value_opt);
        }
        let waker = cx.waker();
        if !wakers.iter().any(|registered| registered.will_wake(waker)) {
            wakers.push(waker.clone());
        }
        Poll::Pending
    }

    /// # Returns
    /// * `None` if the value is not yet set.
    /// * `Some(None)` if the resolver was dropped before setting a value.
    /// * `Some(Some(&T))` if the value has been set.
    fn get(&self) -> Option<Option<&T>> {
        // Acquire pairs with the Release store in `resolve`, so observing `FLAG_VALUE_SET` also makes the written
        // value visible to this thread.
        let flag = self.flag.load(Ordering::Acquire);
        if (flag & FLAG_COMPLETED) == 0 {
            // The value is not yet set.
            return None;
        }
        if (flag & FLAG_VALUE_SET) == 0 {
            // The resolver was dropped before setting a value.
            return Some(None);
        }
        // SAFETY: `FLAG_VALUE_SET` is only stored after `value` was initialized, and `value` is never written again
        // while shared references to `self` exist.
        Some(Some(unsafe { (*self.value.get()).assume_init_ref() }))
    }

    /// Completes the state, storing `Some(value)` or recording that the [`Resolve`] was dropped (`None`), then wakes
    /// and clears every registered waker.
    ///
    /// # Safety
    ///
    /// Must be called at most once, and only by the owning [`Resolve`]. A second call would write `value` while
    /// readers may already hold `&T` into it.
    unsafe fn resolve(&self, value_or_dropped: Option<T>) {
        let flag = match value_or_dropped {
            Some(value) => {
                // SAFETY: the caller guarantees this is the only call, and readers do not touch `value` until they
                // observe `FLAG_VALUE_SET`, which is stored below.
                unsafe { self.value.get().write(MaybeUninit::new(value)) };
                FLAG_COMPLETED | FLAG_VALUE_SET
            }
            // Dropped without a value: mark completed but leave `FLAG_VALUE_SET` clear.
            None => FLAG_COMPLETED,
        };
        // Release pairs with the Acquire load in `get`, publishing the value written above.
        self.flag.store(flag, Ordering::Release);
        // Take the wakers under the lock, but wake them after the guard is released.
        let wakers = std::mem::take(&mut *self.lock_wakers());
        wakers.into_iter().for_each(Waker::wake);
    }
}

impl<T> Drop for Inner<T> {
    /// Drops the stored value, if any. Without this, a resolved `T` would leak because `MaybeUninit` never runs its
    /// contents' destructor.
    fn drop(&mut self) {
        if (*self.flag.get_mut() & FLAG_VALUE_SET) != 0 {
            // SAFETY: `FLAG_VALUE_SET` means `value` was initialized, and `&mut self` guarantees no `&T` handed out by
            // `get` is still alive.
            unsafe { self.value.get_mut().assume_init_drop() };
        }
    }
}

// SAFETY: shared access hands out `&T` to any thread (requires `T: Sync`), and `resolve` moves a `T` in from the
// resolving thread while it may later be dropped on another thread (requires `T: Send`). All access to `value` is
// ordered by the Release/Acquire operations on `flag`.
unsafe impl<T: Sync + Send> Sync for Inner<T> {}

// SAFETY: access to `value` is guarded by the atomic operations on `flag`; moving `Inner` to another thread moves the
// (possibly initialized) `T` with it, so `T: Send` suffices.
unsafe impl<T: Send> Send for Inner<T> {}
213
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic() {
        let (resolve, promise) = channel::<i32>();

        // Resolve from a separate task while this one waits.
        let resolver = tokio::task::spawn(async move { resolve.into_resolve(42) });

        assert_eq!(promise.wait().await, Some(&42));

        resolver.await.unwrap();
    }

    #[tokio::test]
    async fn test_multiple() {
        let (resolve, promise) = channel::<i32>();
        let shared = Arc::new(promise);

        // Several readers share one promise and must all observe the same value.
        let readers: Vec<_> = (0..3)
            .map(|_| {
                let reader_promise = Arc::clone(&shared);
                tokio::task::spawn(async move {
                    assert_eq!(reader_promise.wait().await, Some(&42));
                })
            })
            .collect();
        let resolver = tokio::task::spawn(async move { resolve.into_resolve(42) });

        for reader in readers {
            reader.await.unwrap();
        }
        resolver.await.unwrap();
    }

    #[tokio::test]
    async fn test_try_get() {
        let (resolve, promise) = channel::<i32>();

        assert_eq!(Err(PromiseError::Empty), promise.try_get());
        resolve.into_resolve(42);
        assert_eq!(Ok(&42), promise.try_get());
    }

    #[tokio::test]
    async fn test_dropped() {
        let (resolve, promise) = channel::<i32>();

        // Nothing has happened yet, so the promise is still pending.
        assert!(!promise.is_done());

        // Dropping the resolver without sending completes the promise with no value.
        drop(resolve);
        assert!(promise.is_done());
        assert_eq!(Err(PromiseError::Dropped), promise.try_get());
        assert_eq!(None, promise.wait().await);
    }
}