Skip to main content

async_err/
future_ext.rs

1use std::error::Error;
2use std::future::Future;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7/// Extension trait providing a `.with_context()` method for futures resolving to `Result<T, E>`.
8///
9/// This method allows attaching additional context to errors lazily, by supplying
10/// a closure that is only executed if the future resolves to an error.
11///
12/// # Example
13/// ```
14/// some_async_fn()
15///     .with_context(|err| format!("Failed due to: {}", err))
16///     .await;
17/// ```
18pub trait AsyncResultExt<T, E>: Future<Output = Result<T, E>> + Sized {
19    /// Adds context to an error produced by this future lazily.
20    ///
21    /// The closure `ctx` is called only if the future resolves to an error, producing
22    /// a string context to be attached to the error.
23    ///
24    /// # Parameters
25    /// - `ctx`: closure to create context string from error reference
26    ///
27    /// # Returns
28    /// A future that resolves to `Result<T, AsyncError<E>>`, where errors are wrapped to include context.
29    fn with_context<C>(self, ctx: C) -> WithContext<Self, E, C>
30    where
31        C: FnOnce(&E) -> String,
32    {
33        WithContext {
34            future: self,
35            context: Some(ctx),
36            _marker: PhantomData,
37        }
38    }
39}
40
41impl<T, E, Fut> AsyncResultExt<T, E> for Fut where Fut: Future<Output = Result<T, E>> + Sized {}
42
43/// Future wrapper produced by `.with_context()` to add error context.
44///
45/// Wraps the original future, and on error, attaches the context string lazily generated
46/// by the stored closure.
47pub struct WithContext<Fut, E, C> {
48    future: Fut,
49    context: Option<C>,
50    _marker: PhantomData<E>,
51}
52
53impl<Fut, T, E, C> Future for WithContext<Fut, E, C>
54where
55    Fut: Future<Output = Result<T, E>>,
56    E: Error + 'static,
57    C: FnOnce(&E) -> String,
58{
59    type Output = Result<T, crate::error::AsyncError<E>>;
60
61    /// Polls the wrapped future, converting any error by adding context.
62    ///
63    /// If the wrapped future resolves to `Ok`, passes the value through.
64    /// If `Err`, applies the context closure, wraps the error (without invoking hooks!).
65    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
66        // Safety: projected pinned fields can be safely accessed
67        let this = unsafe { self.get_unchecked_mut() };
68        let fut = unsafe { Pin::new_unchecked(&mut this.future) };
69
70        match fut.poll(cx) {
71            Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)),
72            Poll::Ready(Err(err)) => {
73                let ctx = this.context.take().map(|f| f(&err));
74                let wrapped =
75                    crate::error::AsyncError::new(err).with_context(ctx.unwrap_or_default());
76
77                // Do NOT invoke hooks here — defer hook invocation to caller to avoid duplicates
78
79                Poll::Ready(Err(wrapped))
80            }
81            Poll::Pending => Poll::Pending,
82        }
83    }
84}
85
86/// Extension trait adding `.and_then_async()` for chaining futures returning results.
87///
88/// This allows chaining asynchronous computations that depend on the success of the previous one.
89pub trait AsyncResultChainExt<T, E>: Future<Output = Result<T, E>> + Sized {
90    /// Chains an asynchronous computation to execute if the previous future resolves to `Ok`.
91    ///
92    /// The closure `f` takes the successful value and returns a new future producing a `Result`.
93    ///
94    /// # Parameters
95    /// - `f`: the chaining closure producing the next future.
96    ///
97    /// # Returns
98    /// A future that resolves to the chained computation’s `Result`.
99    fn and_then_async<Fut, F, U>(self, f: F) -> AndThenAsync<Self, Fut, F>
100    where
101        F: FnOnce(T) -> Fut,
102        Fut: Future<Output = Result<U, E>>,
103    {
104        AndThenAsync {
105            state: AndThenAsyncState::First(self, Some(f)),
106        }
107    }
108}
109
110impl<T, E, F> AsyncResultChainExt<T, E> for F where F: Future<Output = Result<T, E>> + Sized {}
111
/// State machine behind [`AndThenAsync`].
///
/// Although this enum is `pub`, it is an implementation detail of [`AndThenAsync`]: the
/// state lives in a private field and is only advanced by its `Future::poll` implementation.
pub enum AndThenAsyncState<Fut1, Fut2, F> {
    /// Polling the first future. The continuation is held until the first future succeeds,
    /// and is taken out (leaving `None`) when it is called.
    First(Fut1, Option<F>),
    /// The first future succeeded; polling the future returned by the continuation.
    Second(Fut2),
    /// A result has already been returned; polling again panics.
    Done,
}

/// Future that chains two async computations sequentially.
///
/// The first future is polled until it completes. On `Ok(value)`, the continuation is
/// called with `value` and the future it returns is polled to completion; that future's
/// result is the result of this one. On `Err(e)`, `e` is returned immediately and the
/// continuation is never called.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct AndThenAsync<Fut1, Fut2, F> {
    // Current stage of the chain.
    state: AndThenAsyncState<Fut1, Fut2, F>,
}

impl<Fut1, Fut2, F, T, U, E> Future for AndThenAsync<Fut1, Fut2, F>
where
    Fut1: Future<Output = Result<T, E>>,
    Fut2: Future<Output = Result<U, E>>,
    F: FnOnce(T) -> Fut2,
{
    type Output = Result<U, E>;

    /// Advances the chain as far as possible without blocking.
    ///
    /// When the first future succeeds, the second future is polled immediately in the same
    /// call, so a chain of two ready futures completes in a single poll.
    ///
    /// # Panics
    /// - With `"Polled after completion"` if polled again after returning `Poll::Ready`.
    /// - With `"FnOnce already taken"` if the first future yields `Ok` a second time, which
    ///   can only happen when a previous poll panicked inside the continuation and the
    ///   future was polled again afterwards.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing is ever moved out of `this`. The futures inside `state` are
        // structurally pinned: they are only accessed through `Pin::new_unchecked` on a
        // reference into `state` (matched by `&mut`, never by value), and they leave `state`
        // only by being dropped in place when `state` is overwritten, which the `Pin` drop
        // guarantee permits. The continuation `F` is not structurally pinned, so moving it out
        // with `Option::take` is fine. This relies on `AndThenAsync` having no `Drop` impl or
        // manual `Unpin` impl that would move the inner futures.
        let this = unsafe { self.get_unchecked_mut() };
        loop {
            match &mut this.state {
                AndThenAsyncState::First(first, continuation) => {
                    // SAFETY: `first` is pinned inside `self` and is never moved (see above).
                    let first = unsafe { Pin::new_unchecked(first) };
                    match first.poll(cx) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(e)) => {
                            // Dropping the first future and the unused continuation in place.
                            this.state = AndThenAsyncState::Done;
                            return Poll::Ready(Err(e));
                        }
                        Poll::Ready(Ok(value)) => {
                            let f = continuation.take().expect("FnOnce already taken");
                            let second = f(value);
                            // Drops the completed first future in place; the loop then polls
                            // the second future right away with the same context.
                            this.state = AndThenAsyncState::Second(second);
                        }
                    }
                }
                AndThenAsyncState::Second(second) => {
                    // SAFETY: `second` is pinned inside `self` and is never moved (see above).
                    let second = unsafe { Pin::new_unchecked(second) };
                    let result = match second.poll(cx) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(result) => result,
                    };
                    this.state = AndThenAsyncState::Done;
                    return Poll::Ready(result);
                }
                AndThenAsyncState::Done => panic!("Polled after completion"),
            }
        }
    }
}