async_promise/lib.rs
1#![warn(missing_docs)]
2#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/", env!("CARGO_PKG_README")))]
3
4use std::cell::UnsafeCell;
5use std::mem::MaybeUninit;
6use std::sync::atomic::{AtomicU8, Ordering};
7use std::sync::{Arc, Mutex, Weak};
8use std::task::{Context, Poll, Waker};
9
10/// Creates a new promise and resolver pair.
11pub fn channel<T>() -> (Resolve<T>, Promise<T>) {
12 let inner = Arc::new(Inner::new());
13 (
14 Resolve {
15 inner: Some(Arc::downgrade(&inner)),
16 },
17 Promise { inner },
18 )
19}
20
21/// A container for a value of type `T` that may not yet be resolved.
22///
23/// Multiple consumers may obtain the value via shared reference (`&`) once resolved.
24pub struct Promise<T> {
25 inner: Arc<Inner<T>>,
26}
27impl<T> Promise<T> {
28 /// Waits for the promise to be resolved, returning the value once set.
29 ///
30 /// Returns `None` if the resolver was dropped before sending a value.
31 pub async fn wait(&self) -> Option<&T> {
32 std::future::poll_fn(|cx| self.inner.poll_get(cx)).await
33 }
34
35 /// Attempts to get the value if it has been resolved, returning an error if not.
36 ///
37 /// Returns [`PromiseError::Empty`] if the value has not yet been set, and [`PromiseError::Dropped`] if the resolver
38 /// was dropped before sending a value.
39 pub fn try_get(&self) -> Result<&T, PromiseError> {
40 self.inner
41 .get()
42 .ok_or(PromiseError::Empty)
43 .and_then(|value_opt| value_opt.ok_or(PromiseError::Dropped))
44 }
45}
46
47/// An error that may occur when trying to get the value from a [`Promise`], see [`Promise::try_get`].
48#[derive(Debug, Eq, PartialEq, Clone, thiserror::Error)]
49pub enum PromiseError {
50 /// The resolver has not yet sent a value.
51 #[error("value not yet sent")]
52 Empty,
53 /// The resolver was dropped before sending a value.
54 #[error("closed before a value was sent")]
55 Dropped,
56}
57
58// SAFETY: must not be clonable, as that would allow simultaneous write access to the `Inner` value.
59/// The resolve/send half of a promise. Created via [`channel`], and used to resolve the corresponding [`Promise`] with
60/// a value.
61pub struct Resolve<T> {
62 inner: Option<Weak<Inner<T>>>,
63}
64impl<T> Resolve<T> {
65 /// Resolves the promise with the given value, consuming this `Resolve` in the process.
66 ///
67 /// This will panic if the promise has already been resolved.
68 pub fn into_resolve(mut self, value: T) {
69 self.resolve(value).unwrap_or_else(|_| panic!("already resolved"));
70 }
71
72 /// Resolves the promise with the given value, returning an error if the promise has already been resolved.
73 pub fn resolve(&mut self, value: T) -> Result<(), T> {
74 let Some(inner) = self.inner.take() else {
75 return Err(value);
76 };
77
78 if let Some(inner) = inner.upgrade() {
79 // SAFETY: `&mut self: Resolve` has exclusive access to `resolve` once.
80 unsafe {
81 inner.resolve(Some(value));
82 }
83 }
84 Ok(())
85 }
86}
87impl<T> Drop for Resolve<T> {
88 fn drop(&mut self) {
89 if let Some(inner) = self.inner.take().and_then(|weak| weak.upgrade()) {
90 // SAFETY: `&mut self: Resolve` has exclusive access to call resolve once. Because we use `inner.take()`, we
91 // know `Self::resolve` was not called.
92 unsafe {
93 inner.resolve(None);
94 }
95 }
96 }
97}
98
const BIT: u8 = 0b1;
/// Set once the [`Resolve`] will never touch `Inner::value` again, because it either resolved the promise or was
/// dropped.
const FLAG_COMPLETED: u8 = BIT;
/// Set once `Inner::value` is initialized and may be read.
const FLAG_VALUE_SET: u8 = BIT << 1;

/// Shared state behind a [`Promise`] / [`Resolve`] pair.
///
/// # Safety
///
/// Any thread with an `&self` may access the `value` field according to the following rules:
///
/// 1. While `FLAG_COMPLETED` is NOT set, the `value` field may be initialized once, by the owning `Resolve` only.
/// 2. Once `FLAG_VALUE_SET` is set, the `value` field may be accessed immutably by any thread.
///
/// If `FLAG_COMPLETED` is set but `FLAG_VALUE_SET` is not, the owning `Resolve` was dropped before the value was set.
///
/// Because `value` is a `MaybeUninit`, it is never dropped implicitly; the `Drop` impl drops it when `FLAG_VALUE_SET`
/// is set.
///
/// # Table of possible states
/// |                       | NOT `FLAG_COMPLETED`  | `FLAG_COMPLETED`                   |
/// |-----------------------|-----------------------|------------------------------------|
/// | NOT `FLAG_VALUE_SET`  | Waiting for `Resolve` | `Resolve` dropped, `value` not set |
/// | `FLAG_VALUE_SET`      | INVALID               | `value` set, accessible            |
struct Inner<T> {
    /// Bitwise OR of `FLAG_COMPLETED` / `FLAG_VALUE_SET`; stored with `Release` in `resolve`, loaded with `Acquire`.
    flag: AtomicU8,
    /// Storage for the resolved value; initialized only when `FLAG_VALUE_SET` is set.
    value: UnsafeCell<MaybeUninit<T>>,
    /// Wakers of tasks waiting for the promise to complete.
    wakers: Mutex<Vec<Waker>>,
}

impl<T> Inner<T> {
    const fn new() -> Self {
        Self {
            flag: AtomicU8::new(0),
            value: UnsafeCell::new(MaybeUninit::uninit()),
            wakers: Mutex::new(Vec::new()),
        }
    }

    /// Locks the waker list, recovering it if a previous holder panicked.
    ///
    /// While locked, the list is only pushed to or taken wholesale, so a panic can never leave it inconsistent and a
    /// poisoned lock is safe to reuse. Recovering (instead of panicking) matters because `resolve` is called from
    /// `Resolve::drop`, where a second panic during unwinding would abort the process.
    fn lock_wakers(&self) -> std::sync::MutexGuard<'_, Vec<Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Polls the promise, registering the current task's `Waker` if the promise is not yet complete.
    fn poll_get<'a>(&'a self, cx: &mut Context<'_>) -> Poll<Option<&'a T>> {
        if let Some(value_opt) = self.get() {
            return Poll::Ready(value_opt);
        }
        let mut wakers = self.lock_wakers();
        // Check again under the lock: `resolve` sets the flag before taking the waker list under this same lock, so
        // either completion is observed here or the waker pushed below is taken and woken by `resolve`.
        if let Some(value_opt) = self.get() {
            return Poll::Ready(value_opt);
        }
        // A task may be polled repeatedly with the same waker before the promise completes; register it only once so
        // repeated polls do not grow the list without bound.
        let waker = cx.waker();
        if !wakers.iter().any(|registered| registered.will_wake(waker)) {
            wakers.push(waker.clone());
        }
        Poll::Pending
    }

    /// # Returns
    /// * `None` if the value is not yet set.
    /// * `Some(None)` if the resolver was dropped before setting a value.
    /// * `Some(Some(&T))` if the value has been set.
    fn get(&self) -> Option<Option<&T>> {
        // Acquire pairs with the Release store in `resolve`, so a thread that observes `FLAG_VALUE_SET` also observes
        // the value written before it.
        let flag = self.flag.load(Ordering::Acquire);
        if (flag & FLAG_COMPLETED) == 0 {
            // The resolver has not completed yet.
            return None;
        }
        if (flag & FLAG_VALUE_SET) == 0 {
            // The resolver was dropped before setting a value.
            return Some(None);
        }
        // SAFETY: `FLAG_VALUE_SET` is set, so `value` is initialized and will never be written again (rule 2).
        Some(Some(unsafe { (*self.value.get()).assume_init_ref() }))
    }

    /// Completes the promise, storing the value if `value_or_dropped` is `Some`, then wakes every registered waiter.
    ///
    /// # Safety
    ///
    /// Must be called at most once, and only by the owning [`Resolve`].
    unsafe fn resolve(&self, value_or_dropped: Option<T>) {
        let flag = match value_or_dropped {
            Some(value) => {
                // SAFETY: the caller guarantees this is the single call, so `FLAG_COMPLETED` is not yet set and no
                // reader may touch `value` until the Release store below publishes it (rule 1).
                unsafe { self.value.get().write(MaybeUninit::new(value)) };
                FLAG_COMPLETED | FLAG_VALUE_SET
            }
            // Resolver dropped without a value: mark completion but leave `FLAG_VALUE_SET` clear.
            None => FLAG_COMPLETED,
        };
        // Release pairs with the Acquire load in `get`, publishing the value written above.
        self.flag.store(flag, Ordering::Release);
        // Take the wakers under the lock but wake them after releasing it, in case a waker re-enters `poll_get` on
        // this thread (the mutex is not reentrant).
        let wakers = {
            let mut guard = self.lock_wakers();
            std::mem::take(&mut *guard)
        };
        wakers.into_iter().for_each(Waker::wake);
    }
}

impl<T> Drop for Inner<T> {
    fn drop(&mut self) {
        // `MaybeUninit` never drops its contents, so a resolved value has to be dropped by hand or it would leak.
        if (*self.flag.get_mut() & FLAG_VALUE_SET) != 0 {
            // SAFETY: `FLAG_VALUE_SET` means `value` was initialized, and `&mut self` guarantees no `&T` handed out by
            // `get` is still alive.
            unsafe { self.value.get_mut().assume_init_drop() };
        }
    }
}
196
197/// Since we can get immutable references to [`Self::value`], this is only `Sync` if `T` is `Sync`, otherwise this
198/// would allow sharing references of `!Sync` values across threads. We need `T` to be `Send` in order for this to be
199/// `Sync` because we can use [`Self::resolve`] to send values (of type T) across threads.
200unsafe impl<T: Sync + Send> Sync for Inner<T> {}
201
202/// Access to [`Self::value`] is guarded by the atomic operations on [`Self::flag`], so as long as `T` itself is `Send`
203/// it's safe to send it to another thread
204unsafe impl<T: Send> Send for Inner<T> {}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic() {
        let (resolve, promise) = channel::<i32>();

        let handle = tokio::task::spawn(async move {
            resolve.into_resolve(42);
        });

        let value = promise.wait().await;
        assert_eq!(Some(&42), value);

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_multiple() {
        let (resolve, promise) = channel::<i32>();
        let promise1 = Arc::new(promise);
        let promise2 = Arc::clone(&promise1);
        let promise3 = Arc::clone(&promise1);

        let read1 = tokio::task::spawn(async move {
            let value = promise1.wait().await;
            assert_eq!(Some(&42), value);
        });
        let read2 = tokio::task::spawn(async move {
            let value = promise2.wait().await;
            assert_eq!(Some(&42), value);
        });
        let read3 = tokio::task::spawn(async move {
            let value = promise3.wait().await;
            assert_eq!(Some(&42), value);
        });
        let resolve = tokio::task::spawn(async move {
            resolve.into_resolve(42);
        });

        read1.await.unwrap();
        read2.await.unwrap();
        read3.await.unwrap();
        resolve.await.unwrap();
    }

    #[tokio::test]
    async fn test_try_get() {
        let (resolve, promise) = channel::<i32>();

        assert_eq!(promise.try_get(), Err(PromiseError::Empty));

        resolve.into_resolve(42);

        assert_eq!(promise.try_get(), Ok(&42));
    }

    #[tokio::test]
    async fn test_dropped() {
        let (resolve, promise) = channel::<i32>();

        // Drop the resolver without resolving it.
        drop(resolve);

        // The promise should return an error when trying to get the value.
        assert_eq!(promise.try_get(), Err(PromiseError::Dropped));

        // The promise should return None when waiting for the value.
        assert_eq!(promise.wait().await, None);
    }

    /// A second `resolve` on the same resolver hands the rejected value back and keeps the first value.
    #[test]
    fn test_resolve_twice_returns_value_back() {
        let (mut resolve, promise) = channel::<i32>();

        assert_eq!(resolve.resolve(1), Ok(()));
        assert_eq!(resolve.resolve(2), Err(2));

        assert_eq!(promise.try_get(), Ok(&1));
    }

    /// Resolving after the promise is gone is not an error; the value is simply discarded.
    #[test]
    fn test_resolve_after_promise_dropped() {
        let (mut resolve, promise) = channel::<i32>();

        drop(promise);

        assert_eq!(resolve.resolve(7), Ok(()));
    }

    /// `into_resolve` panics when the resolver was already used via `resolve`.
    #[test]
    #[should_panic(expected = "already resolved")]
    fn test_into_resolve_after_resolve_panics() {
        let (mut resolve, _promise) = channel::<i32>();

        assert_eq!(resolve.resolve(1), Ok(()));
        resolve.into_resolve(2);
    }
}