async_promise/
lib.rs

1#![warn(missing_docs)]
2#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/", env!("CARGO_PKG_README")))]
3
4use std::cell::UnsafeCell;
5use std::mem::MaybeUninit;
6use std::sync::atomic::{AtomicU8, Ordering};
7use std::sync::{Arc, Mutex, Weak};
8use std::task::{Context, Poll, Waker};
9
10/// Creates a new promise and resolver pair.
11pub fn channel<T>() -> (Resolve<T>, Promise<T>) {
12    let inner = Arc::new(Inner::new());
13    (
14        Resolve {
15            inner: Some(Arc::downgrade(&inner)),
16        },
17        Promise { inner },
18    )
19}
20
21/// A container for a value of type `T` that may not yet be resolved.
22///
23/// Multiple consumers may obtain the value via shared reference (`&`) once resolved.
24pub struct Promise<T> {
25    inner: Arc<Inner<T>>,
26}
27impl<T> Promise<T> {
28    /// Waits for the promise to be resolved, returning the value once set.
29    ///
30    /// Returns `None` if the resolver was dropped before sending a value.
31    pub async fn wait(&self) -> Option<&T> {
32        std::future::poll_fn(|cx| self.inner.poll_get(cx)).await
33    }
34
35    /// Attempts to get the value if it has been resolved, returning an error if not.
36    ///
37    /// Returns [`PromiseError::Empty`] if the value has not yet been set, and [`PromiseError::Dropped`] if the resolver
38    /// was dropped before sending a value.
39    pub fn try_get(&self) -> Result<&T, PromiseError> {
40        self.inner
41            .get()
42            .ok_or(PromiseError::Empty)
43            .and_then(|value_opt| value_opt.ok_or(PromiseError::Dropped))
44    }
45}
46
47/// An error that may occur when trying to get the value from a [`Promise`], see [`Promise::try_get`].
48#[derive(Debug, Eq, PartialEq, Clone, thiserror::Error)]
49pub enum PromiseError {
50    /// The resolver has not yet sent a value.
51    #[error("value not yet sent")]
52    Empty,
53    /// The resolver was dropped before sending a value.
54    #[error("closed before a value was sent")]
55    Dropped,
56}
57
58// SAFETY: must not be clonable, as that would allow simultaneous write access to the `Inner` value.
59/// The resolve/send half of a promise. Created via [`channel`], and used to resolve the corresponding [`Promise`] with
60/// a value.
61pub struct Resolve<T> {
62    inner: Option<Weak<Inner<T>>>,
63}
64impl<T> Resolve<T> {
65    /// Resolves the promise with the given value, consuming this `Resolve` in the process.
66    ///
67    /// This will panic if the promise has already been resolved.
68    pub fn into_resolve(mut self, value: T) {
69        self.resolve(value).unwrap_or_else(|_| panic!("already resolved"));
70    }
71
72    /// Resolves the promise with the given value, returning an error if the promise has already been resolved.
73    pub fn resolve(&mut self, value: T) -> Result<(), T> {
74        let Some(inner) = self.inner.take() else {
75            return Err(value);
76        };
77
78        if let Some(inner) = inner.upgrade() {
79            // SAFETY: `&mut self: Resolve` has exclusive access to `resolve` once.
80            unsafe {
81                inner.resolve(Some(value));
82            }
83        }
84        Ok(())
85    }
86}
87impl<T> Drop for Resolve<T> {
88    fn drop(&mut self) {
89        if let Some(inner) = self.inner.take().and_then(|weak| weak.upgrade()) {
90            // SAFETY: `&mut self: Resolve` has exclusive access to call resolve once. Because we use `inner.take()`, we
91            // know `Self::resolve` was not called.
92            unsafe {
93                inner.resolve(None);
94            }
95        }
96    }
97}
98
/// Lowest bit of the state byte; every flag below is this bit moved into its own position.
const BIT: u8 = 0b0000_0001;
/// Bit 0: the [`Resolve`] will no longer change the [`Inner`] value, because it has either resolved or been dropped.
const FLAG_COMPLETED: u8 = BIT;
/// Bit 1: the value has been written and may be read.
const FLAG_VALUE_SET: u8 = FLAG_COMPLETED << 1;
105
/// Shared state behind a [`Promise`]/[`Resolve`] pair.
///
/// # Safety
///
/// `flag` guards every access to `value`. Any thread holding an `&self` must follow these rules:
///
///  1. Iff `FLAG_COMPLETED` is NOT set, the owning `Resolve` may initialize `value`, exactly once.
///  2. Iff `FLAG_VALUE_SET` is set, any thread may read `value` through a shared reference.
///
/// `FLAG_COMPLETED` without `FLAG_VALUE_SET` means the owning `Resolve` was dropped before it set `value`.
///
/// # Table of possible states
/// |                       | NOT `FLAG_COMPLETED`  | `FLAG_COMPLETED`                   |
/// |-----------------------|-----------------------|------------------------------------|
/// | NOT `FLAG_VALUE_SET`  | Waiting for `Resolve` | `Resolve` dropped, `value` not set |
/// | `FLAG_VALUE_SET`      | INVALID               | `value` set, accessible            |
struct Inner<T> {
    /// Bit set made of `FLAG_COMPLETED` and `FLAG_VALUE_SET`.
    flag: AtomicU8,
    /// Storage for the resolved value; only initialized while `FLAG_VALUE_SET` is set.
    value: UnsafeCell<MaybeUninit<T>>,
    /// Wakers of tasks waiting for the promise to complete.
    wakers: Mutex<Vec<Waker>>,
}
126impl<T> Inner<T> {
127    const fn new() -> Self {
128        Self {
129            flag: AtomicU8::new(0),
130            value: UnsafeCell::new(MaybeUninit::uninit()),
131            wakers: Mutex::new(Vec::new()),
132        }
133    }
134
135    /// Polls the promise, storing the `Waker` if needed.
136    fn poll_get<'a>(&'a self, cx: &mut Context<'_>) -> Poll<Option<&'a T>> {
137        if let Some(value_opt) = self.get() {
138            return Poll::Ready(value_opt);
139        }
140        {
141            // Acquire lock.
142            let mut wakers = self.wakers.lock().unwrap();
143            // Check again in case of race condition with resolver.
144            if let Some(value_opt) = self.get() {
145                return Poll::Ready(value_opt);
146            }
147            // Add the current waker to the list of wakers.
148            wakers.push(cx.waker().clone());
149        }
150        Poll::Pending
151    }
152
153    /// # Returns
154    /// * `None` if the value is not yet set.
155    /// * `Some(None)` if the resolver was dropped before setting a value.
156    /// * `Some(Some(&T))` if the value has been set.
157    fn get(&self) -> Option<Option<&T>> {
158        // Using acquire ordering so any threads that read a true from this
159        // atomic is able to read the value.
160        let flag = self.flag.load(Ordering::Acquire);
161        let completed = 0 != (flag & FLAG_COMPLETED);
162        if completed {
163            let value_set = 0 != (flag & FLAG_VALUE_SET);
164            if value_set {
165                // SAFETY: Value is initialized.
166                Some(Some(unsafe { &*(*self.value.get()).as_ptr() }))
167            } else {
168                // The resolver was dropped before setting a value.
169                Some(None)
170            }
171        } else {
172            // The value is not yet set.
173            None
174        }
175    }
176
177    /// SAFETY: The owning [`Resolve`] may call this once.
178    unsafe fn resolve(&self, value_or_dropped: Option<T>) {
179        let flag = if let Some(value) = value_or_dropped {
180            // SAFETY: `&mut self: Resolve` has exclusive access to set the value once.
181            unsafe {
182                self.value.get().write(MaybeUninit::new(value));
183            }
184            FLAG_COMPLETED | FLAG_VALUE_SET
185        } else {
186            // Dropped, do not set `FLAG_INITIALIZED`.
187            FLAG_COMPLETED
188        };
189        // Using release ordering so any threads that read a true from this
190        // atomic is able to read the value we just stored.
191        self.flag.store(flag, Ordering::Release);
192        let wakers = { std::mem::take(&mut *self.wakers.lock().unwrap()) };
193        wakers.into_iter().for_each(Waker::wake);
194    }
195}
196
197/// Since we can get immutable references to [`Self::value`], this is only `Sync` if `T` is `Sync`, otherwise this
198/// would allow sharing references of `!Sync` values across threads. We need `T` to be `Send` in order for this to be
199/// `Sync` because we can use [`Self::resolve`] to send values (of type T) across threads.
200unsafe impl<T: Sync + Send> Sync for Inner<T> {}
201
202/// Access to [`Self::value`] is guarded by the atomic operations on [`Self::flag`], so as long as `T` itself is `Send`
203/// it's safe to send it to another thread
204unsafe impl<T: Send> Send for Inner<T> {}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic() {
        let (resolve, promise) = channel::<i32>();

        let handle = tokio::task::spawn(async move {
            resolve.into_resolve(42);
        });

        let value = promise.wait().await;
        assert_eq!(Some(&42), value);

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_multiple() {
        let (resolve, promise) = channel::<i32>();
        let promise1 = Arc::new(promise);
        let promise2 = Arc::clone(&promise1);
        let promise3 = Arc::clone(&promise1);

        let read1 = tokio::task::spawn(async move {
            let value = promise1.wait().await;
            assert_eq!(Some(&42), value);
        });
        let read2 = tokio::task::spawn(async move {
            let value = promise2.wait().await;
            assert_eq!(Some(&42), value);
        });
        let read3 = tokio::task::spawn(async move {
            let value = promise3.wait().await;
            assert_eq!(Some(&42), value);
        });
        let resolve = tokio::task::spawn(async move {
            resolve.into_resolve(42);
        });

        read1.await.unwrap();
        read2.await.unwrap();
        read3.await.unwrap();
        resolve.await.unwrap();
    }

    /// `try_get` never awaits, so it does not need an async runtime.
    #[test]
    fn test_try_get() {
        let (resolve, promise) = channel::<i32>();

        assert_eq!(promise.try_get(), Err(PromiseError::Empty));

        resolve.into_resolve(42);

        assert_eq!(promise.try_get(), Ok(&42));
    }

    #[tokio::test]
    async fn test_dropped() {
        let (resolve, promise) = channel::<i32>();

        // Drop the resolver without resolving it.
        drop(resolve);

        // The promise should return an error when trying to get the value.
        assert_eq!(promise.try_get(), Err(PromiseError::Dropped));

        // The promise should return None when waiting for the value.
        assert_eq!(promise.wait().await, None);
    }

    /// A second `resolve` hands the value back instead of overwriting the first one.
    #[test]
    fn test_resolve_twice() {
        let (mut resolve, promise) = channel::<i32>();

        assert_eq!(resolve.resolve(1), Ok(()));
        assert_eq!(resolve.resolve(2), Err(2));
        assert_eq!(promise.try_get(), Ok(&1));
    }

    /// Resolving after the promise is gone is not an error; the value is simply discarded.
    #[test]
    fn test_resolve_after_promise_dropped() {
        let (mut resolve, promise) = channel::<i32>();
        drop(promise);

        assert_eq!(resolve.resolve(7), Ok(()));
        assert_eq!(resolve.resolve(8), Err(8));
    }

    /// `into_resolve` panics when the resolver has already been used.
    #[test]
    #[should_panic(expected = "already resolved")]
    fn test_into_resolve_after_resolve_panics() {
        let (mut resolve, _promise) = channel::<i32>();
        resolve.resolve(1).unwrap();
        resolve.into_resolve(2);
    }
}