owned_future/
funcs.rs

1use core::{
2    future::pending,
3    mem::MaybeUninit,
4    pin::Pin,
5    sync::atomic::AtomicPtr,
6    task::{Context, Poll, Waker},
7};
8
9use alloc::boxed::Box;
10
11use crate::{GetFut, TryGetFut};
12
/// A future that is `Pending` on its first poll and `Ready(())` on every poll after that.
///
/// It never touches the waker before returning `Pending`. That `Pending` is always consumed by
/// an explicit poll in this module (`make`, `try_make` and the async wrappers), which then hands
/// control back to its caller; the next poll of the enclosing future completes this one.
struct PollOnce {
    /// Whether the single `Pending` has already been handed out.
    completed: bool,
}

impl Future for PollOnce {
    type Output = ();

    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `PollOnce` only holds a `bool`, so it is `Unpin` and can be unwrapped from the `Pin`.
        let this = self.get_mut();
        if this.completed {
            return Poll::Ready(());
        }
        this.completed = true;
        Poll::Pending
    }
}
28
29/// Encapsulate a borrowed future along with it's owner
30#[deny(unsafe_code)]
31pub fn make<G>(val: G::Input, getter: G) -> Pin<Box<impl Future<Output = G::Output>>>
32where
33    G: GetFut,
34{
35    let mut future = Box::pin(async move {
36        let mut val = val;
37        let future = getter.get_fut(&mut val);
38        PollOnce { completed: false }.await;
39        future.await
40    });
41
42    let _poll = Future::poll(future.as_mut(), &mut Context::from_waker(Waker::noop()));
43    debug_assert!(matches!(_poll, Poll::Pending));
44
45    future
46}
47
48/// Try to encapsulate a borrowed future along with it's owner
49pub fn try_make<G>(
50    val: G::Input,
51    getter: G,
52) -> Result<(Pin<Box<impl Future<Output = G::Output>>>, G::Aux), (G::Input, G::Error)>
53where
54    G: TryGetFut,
55{
56    let mut result = MaybeUninit::<Result<G::Aux, (G::Input, G::Error)>>::uninit();
57    let result_ptr = AtomicPtr::new(result.as_mut_ptr());
58    let mut future = Box::pin(async move {
59        let mut val = val;
60        let result_ptr = result_ptr.into_inner();
61
62        let err = 'err: {
63            match getter.try_get_fut(&mut val) {
64                Ok((future, aux)) => {
65                    // SAFETY: this is safe to do because `result_ptr` lives in the stack frame
66                    // above us and we write to it exactly once prior to the first poll
67                    unsafe {
68                        result_ptr.write(Ok(aux));
69                    }
70
71                    // return `Pending` and pass control back
72                    PollOnce { completed: false }.await;
73
74                    return future.await;
75                }
76                Err(err) => break 'err err,
77            }
78        };
79
80        // SAFETY: this is safe to do because `result_ptr` lives in the stack frame above us and we
81        // write to it exactly once prior to the first poll
82        unsafe {
83            result_ptr.write(Err((val, err)));
84        }
85
86        // return `Pending` and pass control back
87        pending::<()>().await;
88
89        // The future should be forgotten and this should never be called
90        unreachable!();
91    });
92
93    let _poll = Future::poll(future.as_mut(), &mut Context::from_waker(Waker::noop()));
94    debug_assert!(matches!(_poll, Poll::Pending));
95
96    // SAFETY: this is safe to do because `result` is always written exactly once and always before
97    // the first poll
98    let result = unsafe { result.assume_init() };
99
100    result.map(|aux| (future, aux))
101}
102
#[cfg(feature = "async")]
mod async_feature {
    use core::{
        future::pending,
        mem,
        pin::Pin,
        ptr,
        sync::atomic::{AtomicPtr, Ordering},
        task::{Context, Poll},
    };

    use alloc::{boxed::Box, sync::Arc};
    use pin_project_lite::pin_project;

    use crate::{AsyncSendTryGetFut, AsyncTryGetFut, funcs::PollOnce};

    /// Channel through which a boxed inner future hands its result to the outer poller
    /// ([`AsyncTry`] / [`AsyncSendTry`]).
    ///
    /// The destination is the outer struct's `result` field. Its address is published only for
    /// the duration of one outer `poll` (see [`ResultSlot::lend`]) and is cleared afterwards.
    ///
    /// Keeping a pointer to that field across polls would be unsound, for two reasons:
    /// - The outer types are `Unpin` (their `poll` reaches fields through `Pin`'s `DerefMut`),
    ///   so they may be moved between polls.
    /// - Every later `&mut` access to the field invalidates older pointers derived from it.
    ///
    /// `AtomicPtr` is used instead of a raw pointer because it is `Send + Sync`, which keeps the
    /// inner future of [`AsyncSendTry`] `Send`.
    struct ResultSlot<T> {
        target: Arc<AtomicPtr<Option<T>>>,
    }

    impl<T> Clone for ResultSlot<T> {
        fn clone(&self) -> Self {
            Self {
                target: Arc::clone(&self.target),
            }
        }
    }

    impl<T> ResultSlot<T> {
        /// Create a slot with no destination published.
        fn new() -> Self {
            Self {
                target: Arc::new(AtomicPtr::new(ptr::null_mut())),
            }
        }

        /// Publish `dest` as the write target for [`ResultSlot::put`] while `f` runs.
        ///
        /// The pointer is cleared again before returning, so no address outlives `dest`.
        fn lend<R>(&self, dest: &mut Option<T>, f: impl FnOnce() -> R) -> R {
            let raw: *mut Option<T> = dest;
            self.target.store(raw, Ordering::SeqCst);
            let out = f();
            // If `f` unwinds, the pointer stays set. The only other handle to this slot is the
            // inner future, which the callers below hold in a local that is dropped during
            // that same unwind.
            self.target.store(ptr::null_mut(), Ordering::SeqCst);
            out
        }

        /// Store `value` into the destination currently published by [`ResultSlot::lend`].
        ///
        /// # Panics
        ///
        /// Panics if no destination is currently published.
        ///
        /// # Safety
        ///
        /// Must only be called synchronously from within the `f` passed to
        /// [`ResultSlot::lend`] on a handle to this same slot, i.e. while the inner future is
        /// being polled by its outer poller.
        unsafe fn put(&self, value: T) {
            let dest = self.target.load(Ordering::SeqCst);
            assert!(!dest.is_null(), "result reported outside of the owning poll");
            // SAFETY: `dest` is non-null, so we are inside `lend`. `lend` keeps the unique
            // borrow that `dest` was derived from alive and otherwise untouched for the whole
            // call (caller contract).
            unsafe {
                debug_assert!((*dest).is_none());
                *dest = Some(value);
            }
        }
    }

    /// State machine of [`AsyncTry`].
    enum AsyncTryMakeFuture<'a, G: AsyncTryGetFut<'a>> {
        /// Not polled yet: the owned input and the getter.
        Input(G::Input, G),
        /// Getter in progress inside the boxed future, plus the slot it reports through.
        Future(
            Pin<Box<dyn 'a + Future<Output = G::Output>>>,
            ResultSlot<Result<G::Aux, (G::Input, G::Error)>>,
        ),
        /// Completed, or a previous poll panicked.
        Done,
    }

    pin_project! {
        /// Try to encapsulate an async borrowed future along with its owner.
        ///
        /// Resolves once `async_try_get_fut` has finished. On success it yields the boxed
        /// borrowed future and `G::Aux`; on failure it yields the owned input and the error.
        pub struct AsyncTry<'a, G: AsyncTryGetFut<'a>> {
            result: Option<Result<G::Aux, (G::Input, G::Error)>>,
            future: AsyncTryMakeFuture<'a, G>
        }
    }

    impl<'a, G: AsyncTryGetFut<'a>> AsyncTry<'a, G> {
        /// Store `val` and `getter`; nothing runs until the first poll.
        pub fn new(val: G::Input, getter: G) -> Self {
            Self {
                result: None,
                future: AsyncTryMakeFuture::Input(val, getter),
            }
        }
    }

    /// Output of [`AsyncTry`]: the boxed future together with `G::Aux` on success, or the owned
    /// input handed back together with the error on failure.
    ///
    /// This will stop being `Box<dyn _>` once either `type_alias_impl_trait` or
    /// `impl_trait_in_assoc_type` stabilize
    pub type AsyncTryOutput<'a, G> = Result<
        (
            Pin<Box<dyn 'a + Future<Output = <G as AsyncTryGetFut<'a>>::Output>>>,
            <G as AsyncTryGetFut<'a>>::Aux,
        ),
        (
            <G as AsyncTryGetFut<'a>>::Input,
            <G as AsyncTryGetFut<'a>>::Error,
        ),
    >;

    impl<'a, G: AsyncTryGetFut<'a>> Future for AsyncTry<'a, G> {
        type Output = AsyncTryOutput<'a, G>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let state = mem::replace(&mut self.future, AsyncTryMakeFuture::Done);
            let (mut future, slot) = match state {
                AsyncTryMakeFuture::Done => panic!("`AsyncTry` polled after completion"),
                AsyncTryMakeFuture::Input(val, getter) => {
                    let slot = ResultSlot::new();
                    let reporter = slot.clone();
                    let future: Pin<Box<dyn 'a + Future<Output = G::Output>>> =
                        Box::pin(async move {
                            let mut val = val;
                            let err = match getter.async_try_get_fut(&mut val).await {
                                Ok((borrowed, aux)) => {
                                    // SAFETY: until it has reported, this future is only
                                    // polled by `AsyncTry::poll`, inside `ResultSlot::lend`.
                                    unsafe { reporter.put(Ok(aux)) };
                                    // Return `Pending` so `AsyncTry::poll` can hand us out.
                                    PollOnce { completed: false }.await;
                                    return borrowed.await;
                                }
                                Err(err) => err,
                            };
                            // SAFETY: same reasoning as on the `Ok` path.
                            unsafe { reporter.put(Err((val, err))) };

                            // `AsyncTry::poll` drops this future as soon as it sees the error.
                            pending::<()>().await;
                            unreachable!();
                        });
                    (future, slot)
                }
                AsyncTryMakeFuture::Future(future, slot) => (future, slot),
            };

            // Publish `self.result` only for the duration of this single inner poll.
            let _poll = slot.lend(&mut self.result, || future.as_mut().poll(cx));
            debug_assert!(matches!(_poll, Poll::Pending));

            match self.result.take() {
                Some(result) => Poll::Ready(result.map(|aux| (future, aux))),
                None => {
                    self.future = AsyncTryMakeFuture::Future(future, slot);
                    Poll::Pending
                }
            }
        }
    }

    /// State machine of [`AsyncSendTry`].
    enum AsyncSendTryMakeFuture<'a, G: AsyncSendTryGetFut<'a>> {
        /// Not polled yet: the owned input and the getter.
        Input(G::Input, G),
        /// Getter in progress inside the boxed future, plus the slot it reports through.
        Future(
            Pin<Box<dyn 'a + Send + Future<Output = G::Output>>>,
            ResultSlot<Result<G::Aux, (G::Input, G::Error)>>,
        ),
        /// Completed, or a previous poll panicked.
        Done,
    }

    pin_project! {
        /// Try to encapsulate an async borrowed future along with its owner, producing a `Send`
        /// boxed future.
        ///
        /// Resolves once `async_send_try_get_fut` has finished. On success it yields the boxed
        /// borrowed future and `G::Aux`; on failure it yields the owned input and the error.
        pub struct AsyncSendTry<'a, G: AsyncSendTryGetFut<'a>> {
            result: Option<Result<G::Aux, (G::Input, G::Error)>>,
            future: AsyncSendTryMakeFuture<'a, G>
        }
    }

    impl<'a, G: AsyncSendTryGetFut<'a>> AsyncSendTry<'a, G> {
        /// Store `val` and `getter`; nothing runs until the first poll.
        pub fn new(val: G::Input, getter: G) -> Self {
            Self {
                result: None,
                future: AsyncSendTryMakeFuture::Input(val, getter),
            }
        }
    }

    /// Output of [`AsyncSendTry`]: the boxed `Send` future together with `G::Aux` on success,
    /// or the owned input handed back together with the error on failure.
    ///
    /// This will stop being `Box<dyn _>` once either `type_alias_impl_trait` or
    /// `impl_trait_in_assoc_type` stabilize
    pub type AsyncSendTryOutput<'a, G> = Result<
        (
            Pin<Box<dyn 'a + Send + Future<Output = <G as AsyncSendTryGetFut<'a>>::Output>>>,
            <G as AsyncSendTryGetFut<'a>>::Aux,
        ),
        (
            <G as AsyncSendTryGetFut<'a>>::Input,
            <G as AsyncSendTryGetFut<'a>>::Error,
        ),
    >;

    impl<'a, G: AsyncSendTryGetFut<'a>> Future for AsyncSendTry<'a, G> {
        type Output = AsyncSendTryOutput<'a, G>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let state = mem::replace(&mut self.future, AsyncSendTryMakeFuture::Done);
            let (mut future, slot) = match state {
                AsyncSendTryMakeFuture::Done => panic!("`AsyncSendTry` polled after completion"),
                AsyncSendTryMakeFuture::Input(val, getter) => {
                    let slot = ResultSlot::new();
                    // `ResultSlot` is `Send` regardless of its payload, so capturing it keeps
                    // the inner future `Send`.
                    let reporter = slot.clone();
                    let future: Pin<Box<dyn 'a + Send + Future<Output = G::Output>>> =
                        Box::pin(async move {
                            let mut val = val;
                            let err = match getter.async_send_try_get_fut(&mut val).await {
                                Ok((borrowed, aux)) => {
                                    // SAFETY: until it has reported, this future is only
                                    // polled by `AsyncSendTry::poll`, inside `ResultSlot::lend`.
                                    unsafe { reporter.put(Ok(aux)) };
                                    // Return `Pending` so `AsyncSendTry::poll` can hand us out.
                                    PollOnce { completed: false }.await;
                                    return borrowed.await;
                                }
                                Err(err) => err,
                            };
                            // SAFETY: same reasoning as on the `Ok` path.
                            unsafe { reporter.put(Err((val, err))) };

                            // `AsyncSendTry::poll` drops this future as soon as it sees the error.
                            pending::<()>().await;
                            unreachable!();
                        });
                    // todo!("blocked on rust-lang/rust#100013")
                    (future, slot)
                }
                AsyncSendTryMakeFuture::Future(future, slot) => (future, slot),
            };

            // Publish `self.result` only for the duration of this single inner poll.
            let _poll = slot.lend(&mut self.result, || future.as_mut().poll(cx));
            debug_assert!(matches!(_poll, Poll::Pending));

            match self.result.take() {
                Some(result) => Poll::Ready(result.map(|aux| (future, aux))),
                None => {
                    self.future = AsyncSendTryMakeFuture::Future(future, slot);
                    Poll::Pending
                }
            }
        }
    }
}
322
323#[cfg(feature = "async")]
324pub use async_feature::*;