async_promise/lib.rs
1#![warn(missing_docs)]
2#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/", env!("CARGO_PKG_README")))]
3
4use std::cell::UnsafeCell;
5use std::mem::MaybeUninit;
6use std::sync::atomic::{AtomicU8, Ordering};
7use std::sync::{Arc, Mutex, Weak};
8use std::task::{Context, Poll, Waker};
9
10/// Creates a new promise and resolver pair.
11pub fn channel<T>() -> (Resolve<T>, Promise<T>) {
12 let inner = Arc::new(Inner::new());
13 (
14 Resolve {
15 inner: Some(Arc::downgrade(&inner)),
16 },
17 Promise { inner },
18 )
19}
20
21/// A container for a value of type `T` that may not yet be resolved.
22///
23/// Multiple consumers may obtain the value via shared reference (`&`) once resolved.
24pub struct Promise<T> {
25 inner: Arc<Inner<T>>,
26}
27impl<T> Promise<T> {
28 /// Waits for the promise to be resolved, returning the value once set.
29 ///
30 /// Returns `None` if the resolver was dropped before sending a value.
31 pub async fn wait(&self) -> Option<&T> {
32 std::future::poll_fn(|cx| self.inner.poll_get(cx)).await
33 }
34
35 /// Attempts to get the value if it has been resolved, returning an error if not.
36 ///
37 /// Returns [`PromiseError::Empty`] if the value has not yet been set, and [`PromiseError::Dropped`] if the resolver
38 /// was dropped before sending a value.
39 pub fn try_get(&self) -> Result<&T, PromiseError> {
40 self.inner
41 .get()
42 .ok_or(PromiseError::Empty)
43 .and_then(|value_opt| value_opt.ok_or(PromiseError::Dropped))
44 }
45
46 /// Returns if the `Resolver` has been resolved OR dropped.
47 ///
48 /// * `true` indicates the `Promise` is complete, either by resolving or dropping.
49 /// * `false` indicates the value may still resolve in the future.
50 pub fn is_done(&self) -> bool {
51 self.inner.get().is_some()
52 }
53}
54
55/// An error that may occur when trying to get the value from a [`Promise`], see [`Promise::try_get`].
56#[derive(Debug, Eq, PartialEq, Clone, thiserror::Error)]
57pub enum PromiseError {
58 /// The resolver has not yet sent a value.
59 #[error("value not yet sent")]
60 Empty,
61 /// The resolver was dropped before sending a value.
62 #[error("closed before a value was sent")]
63 Dropped,
64}
65
66// SAFETY: must not be clonable, as that would allow simultaneous write access to the `Inner` value.
67/// The resolve/send half of a promise. Created via [`channel`], and used to resolve the corresponding [`Promise`] with
68/// a value.
69pub struct Resolve<T> {
70 inner: Option<Weak<Inner<T>>>,
71}
72impl<T> Resolve<T> {
73 /// Resolves the promise with the given value, consuming this `Resolve` in the process.
74 ///
75 /// This will panic if the promise has already been resolved.
76 pub fn into_resolve(mut self, value: T) {
77 self.resolve(value).unwrap_or_else(|_| panic!("already resolved"));
78 }
79
80 /// Resolves the promise with the given value, returning an error if the promise has already been resolved.
81 pub fn resolve(&mut self, value: T) -> Result<(), T> {
82 let Some(inner) = self.inner.take() else {
83 return Err(value);
84 };
85
86 if let Some(inner) = inner.upgrade() {
87 // SAFETY: `&mut self: Resolve` has exclusive access to `resolve` once.
88 unsafe {
89 inner.resolve(Some(value));
90 }
91 }
92 Ok(())
93 }
94}
95impl<T> Drop for Resolve<T> {
96 fn drop(&mut self) {
97 if let Some(inner) = self.inner.take().and_then(|weak| weak.upgrade()) {
98 // SAFETY: `&mut self: Resolve` has exclusive access to call resolve once. Because we use `inner.take()`, we
99 // know `Self::resolve` was not called.
100 unsafe {
101 inner.resolve(None);
102 }
103 }
104 }
105}
106
/// Lowest bit; the state flags below are built from it.
const BIT: u8 = 0x01;
/// Set once the [`Resolve`] will no longer change the value of the [`Inner`], because it was either resolved or
/// dropped.
const FLAG_COMPLETED: u8 = BIT;
/// Set once the value has been written and can be read.
const FLAG_VALUE_SET: u8 = FLAG_COMPLETED << 1;
113
114/// # Safety
115///
116/// Any thread with an `&self` may access the `value` field according the following rules:
117///
118/// 1. Iff NOT `FLAG_COMPLETED`, the `value` field may be initialized by the owning `Resolve` once.
119/// 2. Iff `FLAG_VALUE_SET`, the `value` field may be accessed immutably by any thread.
120///
121/// If `FLAG_COMPLETED` but NOT `FLAG_VALUE_SET`, then the owning `Resolve` was dropped before the value was set.
122///
123/// # Table of possible states
124/// | | NOT `FLAG_COMPLETED` | `FLAG_COMPLETED` |
125/// |-----------------------|-----------------------|------------------------------------|
126/// | NOT `FLAG_VALUE_SET` | Waiting for `Resolve` | `Resolve` dropped, `value` not set |
127/// | `FLAG_VALUE_SET` | INVALID | `value` set, accessible |
128struct Inner<T> {
129 flag: AtomicU8,
130 value: UnsafeCell<MaybeUninit<T>>,
131 /// List of wakers waiting for the value to be set.
132 wakers: Mutex<Vec<Waker>>,
133}
134impl<T> Inner<T> {
135 const fn new() -> Self {
136 Self {
137 flag: AtomicU8::new(0),
138 value: UnsafeCell::new(MaybeUninit::uninit()),
139 wakers: Mutex::new(Vec::new()),
140 }
141 }
142
143 /// Polls the promise, storing the `Waker` if needed.
144 fn poll_get<'a>(&'a self, cx: &mut Context<'_>) -> Poll<Option<&'a T>> {
145 if let Some(value_opt) = self.get() {
146 return Poll::Ready(value_opt);
147 }
148 {
149 // Acquire lock.
150 let mut wakers = self.wakers.lock().unwrap();
151 // Check again in case of race condition with resolver.
152 if let Some(value_opt) = self.get() {
153 return Poll::Ready(value_opt);
154 }
155 // Add the current waker to the list of wakers.
156 wakers.push(cx.waker().clone());
157 }
158 Poll::Pending
159 }
160
161 /// # Returns
162 /// * `None` if the value is not yet set.
163 /// * `Some(None)` if the resolver was dropped before setting a value.
164 /// * `Some(Some(&T))` if the value has been set.
165 fn get(&self) -> Option<Option<&T>> {
166 // Using acquire ordering so any threads that read a true from this
167 // atomic is able to read the value.
168 let flag = self.flag.load(Ordering::Acquire);
169 let completed = 0 != (flag & FLAG_COMPLETED);
170 if completed {
171 let value_set = 0 != (flag & FLAG_VALUE_SET);
172 if value_set {
173 // SAFETY: Value is initialized.
174 Some(Some(unsafe { &*(*self.value.get()).as_ptr() }))
175 } else {
176 // The resolver was dropped before setting a value.
177 Some(None)
178 }
179 } else {
180 // The value is not yet set.
181 None
182 }
183 }
184
185 /// SAFETY: The owning [`Resolve`] may call this once.
186 unsafe fn resolve(&self, value_or_dropped: Option<T>) {
187 let flag = if let Some(value) = value_or_dropped {
188 // SAFETY: `&mut self: Resolve` has exclusive access to set the value once.
189 unsafe {
190 self.value.get().write(MaybeUninit::new(value));
191 }
192 FLAG_COMPLETED | FLAG_VALUE_SET
193 } else {
194 // Dropped, do not set `FLAG_INITIALIZED`.
195 FLAG_COMPLETED
196 };
197 // Using release ordering so any threads that read a true from this
198 // atomic is able to read the value we just stored.
199 self.flag.store(flag, Ordering::Release);
200 let wakers = { std::mem::take(&mut *self.wakers.lock().unwrap()) };
201 wakers.into_iter().for_each(Waker::wake);
202 }
203}
204
205/// Since we can get immutable references to [`Self::value`], this is only `Sync` if `T` is `Sync`, otherwise this
206/// would allow sharing references of `!Sync` values across threads. We need `T` to be `Send` in order for this to be
207/// `Sync` because we can use [`Self::resolve`] to send values (of type T) across threads.
208unsafe impl<T: Sync + Send> Sync for Inner<T> {}
209
210/// Access to [`Self::value`] is guarded by the atomic operations on [`Self::flag`], so as long as `T` itself is `Send`
211/// it's safe to send it to another thread
212unsafe impl<T: Send> Send for Inner<T> {}
213
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic() {
        let (resolve, promise) = channel::<i32>();

        let handle = tokio::task::spawn(async move {
            resolve.into_resolve(42);
        });

        let value = promise.wait().await;
        assert_eq!(Some(&42), value);

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_multiple() {
        let (resolve, promise) = channel::<i32>();
        let promise1 = Arc::new(promise);
        let promise2 = Arc::clone(&promise1);
        let promise3 = Arc::clone(&promise1);

        let read1 = tokio::task::spawn(async move {
            let value = promise1.wait().await;
            assert_eq!(Some(&42), value);
        });
        let read2 = tokio::task::spawn(async move {
            let value = promise2.wait().await;
            assert_eq!(Some(&42), value);
        });
        let read3 = tokio::task::spawn(async move {
            let value = promise3.wait().await;
            assert_eq!(Some(&42), value);
        });
        let resolve = tokio::task::spawn(async move {
            resolve.into_resolve(42);
        });

        read1.await.unwrap();
        read2.await.unwrap();
        read3.await.unwrap();
        resolve.await.unwrap();
    }

    #[tokio::test]
    async fn test_try_get() {
        let (resolve, promise) = channel::<i32>();

        assert_eq!(promise.try_get(), Err(PromiseError::Empty));

        resolve.into_resolve(42);

        assert_eq!(promise.try_get(), Ok(&42));
    }

    #[tokio::test]
    async fn test_dropped() {
        let (resolve, promise) = channel::<i32>();

        // The promise should still be active.
        assert!(!promise.is_done());

        // Drop the resolver without resolving it.
        drop(resolve);

        // The promise should be done
        assert!(promise.is_done());

        // The promise should return an error when trying to get the value.
        assert_eq!(promise.try_get(), Err(PromiseError::Dropped));

        // The promise should return None when waiting for the value.
        assert_eq!(promise.wait().await, None);
    }

    #[tokio::test]
    async fn test_dropped_while_waiting() {
        let (resolve, promise) = channel::<i32>();

        // Drop the resolver from another task, possibly while this one is already waiting.
        let handle = tokio::task::spawn(async move {
            drop(resolve);
        });

        assert_eq!(promise.wait().await, None);
        assert_eq!(promise.try_get(), Err(PromiseError::Dropped));

        handle.await.unwrap();
    }

    #[test]
    fn test_resolve_twice() {
        let (mut resolve, promise) = channel::<i32>();

        assert_eq!(resolve.resolve(1), Ok(()));
        // A second attempt hands the value back instead of overwriting the first one.
        assert_eq!(resolve.resolve(2), Err(2));
        assert_eq!(promise.try_get(), Ok(&1));
        assert!(promise.is_done());
    }

    #[test]
    fn test_resolve_after_promise_dropped() {
        let (mut resolve, promise) = channel::<i32>();
        drop(promise);

        // Nobody can observe the value any more; it is discarded without an error.
        assert_eq!(resolve.resolve(42), Ok(()));
        // The resolver is still considered used afterwards.
        assert_eq!(resolve.resolve(43), Err(43));
    }

    #[test]
    #[should_panic(expected = "already resolved")]
    fn test_into_resolve_after_resolve_panics() {
        let (mut resolve, _promise) = channel::<i32>();
        resolve.resolve(1).unwrap();
        resolve.into_resolve(2);
    }
}