asynx/
sync.rs

//! The thread-safe (`Send`/`Sync`) implementation of the exception context.
2
3use core::sync::atomic::AtomicU8;
4
// Lifecycle of `ExceptionContext::status`. The state only ever moves forward:
// NO_EXCEPTION -> THROWING -> THROWN -> MOVED.

/// Nothing has been thrown; `exception` is uninitialized.
const NO_EXCEPTION: u8 = 0;
/// `throw` has claimed the slot and is writing `exception`.
const THROWING: u8 = 1;
/// `exception` holds an initialized value that has not been taken yet.
const THROWN: u8 = 2;
/// The exception has been moved out; `exception` must not be read or dropped again.
const MOVED: u8 = 3;
9
/// Context through which an exception can be thrown and later caught.
///
/// The type is `Send` and `Sync` (given `E: Send`) mainly so that the futures produced by
/// async functions using this context are themselves `Send`/`Sync`. Operating on it from
/// several threads at once is memory-safe, but it is not designed for concurrent use.
pub struct ExceptionContext<E> {
    /// Current lifecycle state (`NO_EXCEPTION`, `THROWING`, `THROWN` or `MOVED`); it guards
    /// every access to `exception`.
    status: AtomicU8,
    /// Slot for the thrown exception. It contains an initialized, not-yet-taken `E` once
    /// `status` reaches `THROWN`, and until `status` becomes `MOVED`.
    exception: core::cell::UnsafeCell<core::mem::MaybeUninit<E>>,
}
19
20// SAFETY: There are no methods that deal with `&E`, so `E` is not required to be `Sync`. You can
21// imagine that `E` is wrapped in a [`SyncWrapper`], a type that doesn't allow creation of `&E`s
22// and is therefore unconditionally Sync.
23//
24// https://docs.rs/sync_wrapper/0.1/sync_wrapper/struct.SyncWrapper.html
25unsafe impl<E: Send> Send for ExceptionContext<E> {}
26// `E` must be `Send` as this type allows a shared reference to it to take ownership of `E`.
27unsafe impl<E: Send> Sync for ExceptionContext<E> {}
28
29impl<E> Drop for ExceptionContext<E> {
30    fn drop(&mut self) {
31        if *self.status.get_mut() == THROWN {
32            // SAFETY: when the status is `THROWN`, `exception` has an unmoved initialized value.
33            let e = unsafe { self.exception.get().read().assume_init() };
34            drop(e)
35        }
36    }
37}
38
39impl<E> Default for ExceptionContext<E> {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl<E> ExceptionContext<E> {
46    /// Create a new exception context.
47    pub const fn new() -> Self {
48        Self {
49            status: AtomicU8::new(0),
50            exception: core::cell::UnsafeCell::new(core::mem::MaybeUninit::uninit()),
51        }
52    }
53
54    /// Throws an exception. You should always `await` the result.
55    ///
56    /// Example:
57    ///
58    /// ```rust
59    /// tokio_test::block_on(async {
60    ///     let r = tokio::spawn(async {
61    ///         asynx::sync::ExceptionContext::<String>::new()
62    ///             .catch(|ctx| async move {
63    ///                 ctx.throw("failed".to_string()).await;
64    ///                 unreachable!()
65    ///             }).await
66    ///      }).await.unwrap();
67    ///
68    ///     assert_eq!(Err("failed".to_string()), r)
69    /// })
70    /// ```
71    pub async fn throw(&self, exception: E) -> ! {
72        if self
73            .status
74            .compare_exchange(
75                NO_EXCEPTION,
76                THROWING,
77                // No specific ordering is required on success because it's not possible for
78                // `exception` to be written to before this swap succeds; there are no Release
79                // fences that occur after a write to `exception` that any Acquire fence here would
80                // need to synchronize with.
81                core::sync::atomic::Ordering::Relaxed,
82                core::sync::atomic::Ordering::Relaxed,
83            )
84            .is_err()
85        {
86            panic!("`throw` calls more than once")
87        }
88        // SAFETY: we compare-exchange from NO_EXCEPTION to THROWING,
89        // and the status won't be `NO_EXCEPTION` again.
90        // So the compare-exchange will only succeed once, so there is no concurrent write.
91        // Also, all reads on `exception` are performed only after status being written `THROWN`.
92        // This happens after the exception is written,
93        // so there is no concurrent read.
94        unsafe { (&mut *self.exception.get()).write(exception) };
95        // Release is necessary to ensure that any (correctly Acquired) reads of `exception` that
96        // happen after this point will be able to see our newly-written value.
97        self.status
98            .store(THROWN, core::sync::atomic::Ordering::Release);
99        core::future::pending().await
100    }
101
102    /// Executes the function `f` providing the context, then returns a Future that
103    /// catches the thrown exception.
104    ///
105    /// Example:
106    ///
107    /// ```rust
108    /// tokio_test::block_on(async {
109    ///     let r = tokio::spawn(async {
110    ///         asynx::sync::ExceptionContext::<String>::new()
111    ///             .catch(|_| async {
112    ///                 "success".to_string()
113    ///             }).await
114    ///     }).await.unwrap();
115    ///     assert_eq!(Ok("success".to_string()), r);
116    ///
117    ///     let r = tokio::spawn(async {
118    ///         asynx::sync::ExceptionContext::<String>::new()
119    ///             .catch(|ctx| async {
120    ///                 ctx.throw("failed".to_string()).await;
121    ///                 unreachable!()
122    ///             }).await
123    ///      }).await.unwrap();
124    ///
125    ///     assert_eq!(Err("failed".to_string()), r)
126    /// })
127    /// ```
128    ///
129    /// Note that unlike the unsync version, for [crate::unsync::ExceptionContext],
130    /// you can only call `catch` once on each context. Calling multiple times causes
131    /// panic. You need to create a context for each catching.
132    ///
133    /// ```
134    /// use asynx::sync::ExceptionContext;
135    ///
136    /// tokio_test::block_on(async {
137    ///     let r = ExceptionContext::<u32>::new()
138    ///         .catch(|ctx| async {
139    ///             let r = ExceptionContext::<u32>::new().catch(|ctx| async {
140    ///                 ctx.throw(1).await
141    ///             }).await;
142    ///             assert_eq!(Err(1), r);
143    ///             
144    ///             let r = ExceptionContext::<u32>::new().catch(|ctx| async {
145    ///                 ctx.throw(2).await
146    ///             }).await;
147    ///             assert_eq!(Err(2), r);
148    ///
149    ///             ctx.throw(3).await;
150    ///         }).await;
151    ///     assert_eq!(Err(3), r)
152    /// })
153    /// ```
154    pub fn catch<'a, Fu: core::future::Future, F: Fn(&'a Self) -> Fu>(
155        &'a self,
156        f: F,
157    ) -> Catching<'a, E, Fu> {
158        Catching {
159            ctx: self,
160            future: f(self),
161        }
162    }
163
164    fn try_take_exception(&self) -> Option<E> {
165        if self
166            .status
167            .compare_exchange(
168                THROWN,
169                MOVED,
170                // On success, Acquire is necessary to ensure that our read of `exception`
171                // happens-after it is written to by `throw`.
172                core::sync::atomic::Ordering::Acquire,
173                // We don't read any shared state on failure so there's no need for any memory
174                // orderings.
175                core::sync::atomic::Ordering::Relaxed,
176            )
177            .is_ok()
178        {
179            // SAFETY: status is changed from THROWN to MOVED,
180            // but writes on exception only happens after status changed
181            // from NO_EXCEPTION to THROWING, so there is no concurrent write.
182            //
183            // Because the status was THROWN before this write, the `exception` has an initialized value.
184            // We can move it out because after status becomes MOVED, the value won't be dropped by the context itself.
185            Some(unsafe { self.exception.get().read().assume_init() })
186        } else {
187            None
188        }
189    }
190}
191
192pin_project_lite::pin_project! {
193    /// A wrapper future that catches the exception.
194    ///
195    /// It outputs a result with the exception as error.
196    pub struct Catching<'a, E, F> {
197        ctx: &'a ExceptionContext<E>,
198        #[pin]
199        future: F,
200    }
201}
202
203impl<'a, E, F: core::future::Future> core::future::Future for Catching<'a, E, F> {
204    type Output = Result<F::Output, E>;
205
206    fn poll(
207        self: core::pin::Pin<&mut Self>,
208        cx: &mut core::task::Context<'_>,
209    ) -> core::task::Poll<Self::Output> {
210        let this = self.project();
211        let p = this.future.poll(cx);
212        if let Some(exception) = this.ctx.try_take_exception() {
213            core::task::Poll::Ready(Err(exception))
214        } else {
215            p.map(Ok)
216        }
217    }
218}