duvet_core/
query.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use futures::{future::poll_fn, ready, FutureExt};
5use std::{
6    cell::UnsafeCell,
7    fmt,
8    future::{Future, Pending},
9    mem::MaybeUninit,
10    pin::Pin,
11    sync::{
12        atomic::{AtomicBool, Ordering},
13        Arc,
14    },
15    task::{Context, Poll},
16};
17use tokio::sync::Semaphore;
18
19pub struct Query<T> {
20    inner: Arc<dyn InnerState<Output = T>>,
21}
22
23impl<T> Clone for Query<T> {
24    fn clone(&self) -> Self {
25        Self {
26            inner: self.inner.clone(),
27        }
28    }
29}
30
31impl<T: 'static + Send + Sync + fmt::Debug> fmt::Debug for Query<T> {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        f.debug_tuple("Query").field(&self.try_get()).finish()
34    }
35}
36
37unsafe impl<T: Sync + Send> Sync for Query<T> {}
38unsafe impl<T: Sync + Send> Send for Query<T> {}
39
40impl<T: 'static + Send + Sync> From<T> for Query<T> {
41    fn from(value: T) -> Self {
42        let semaphore = Semaphore::new(0);
43        semaphore.close();
44
45        let future = UnsafeCell::new(FutureState::<Pending<T>>::Finished);
46
47        let inner = Inner {
48            value_set: AtomicBool::new(true),
49            value: UnsafeCell::new(MaybeUninit::new(value)),
50            semaphore,
51            future,
52        };
53
54        Query {
55            inner: Arc::new(inner),
56        }
57    }
58}
59
60impl<T: 'static + Send + Sync> Query<T> {
61    pub fn new<F: 'static + Future<Output = T> + Send>(future: F) -> Self {
62        let inner = Inner {
63            value_set: AtomicBool::new(false),
64            value: UnsafeCell::new(MaybeUninit::uninit()),
65            semaphore: Semaphore::new(1),
66            future: UnsafeCell::new(FutureState::Init(future)),
67        };
68
69        Self {
70            inner: Arc::new(inner),
71        }
72    }
73
74    pub fn delegate<F: 'static + Future<Output = Query<T>> + Send>(future: F) -> Self {
75        let inner = Inner {
76            value_set: AtomicBool::new(false),
77            value: UnsafeCell::new(MaybeUninit::uninit()),
78            semaphore: Semaphore::new(1),
79            future: UnsafeCell::new(FutureState::Init(future)),
80        };
81
82        let inner = Delegate {
83            inner,
84            query_fut: UnsafeCell::new(None),
85        };
86
87        Self {
88            inner: Arc::new(inner),
89        }
90    }
91
92    pub fn spawn<F: 'static + Future<Output = T> + Send>(future: F) -> Self {
93        let inner = Spawn {
94            value_set: AtomicBool::new(false),
95            value: UnsafeCell::new(MaybeUninit::uninit()),
96            semaphore: Semaphore::new(1),
97            future: UnsafeCell::new(SpawnFutureState::Init(future)),
98        };
99
100        Self {
101            inner: Arc::new(inner),
102        }
103    }
104
105    pub async fn get(&self) -> &T {
106        if let Some(value) = self.try_get() {
107            return value;
108        }
109
110        // Here we try to acquire the semaphore permit. Holding the permit
111        // will allow us to set the value of the Query, and prevents
112        // other tasks from initializing the Query while we are holding
113        // it.
114        if let Ok(permit) = self.inner.semaphore().acquire().await {
115            debug_assert!(!self.inner.initialized());
116
117            // If `f()` panics or `select!` is called, this
118            // `get_or_init` call is aborted and the semaphore permit is
119            // dropped.
120            poll_fn(move |cx| unsafe {
121                // SAFETY: polling is guarded by semaphores
122                self.inner.poll(cx)
123            })
124            .await;
125
126            permit.forget();
127        }
128
129        // SAFETY: The semaphore has been closed. This only happens
130        // when the Query is fully initialized.
131        unsafe { self.inner.get_unchecked() }
132    }
133
134    pub fn try_get(&self) -> Option<&T> {
135        self.inner.try_get()
136    }
137
138    pub fn map<M, F, R>(&self, m: M) -> Query<R>
139    where
140        M: 'static + Send + FnOnce(&T) -> F,
141        F: 'static + Send + Future<Output = R>,
142        R: 'static + Send + Sync,
143    {
144        let inner = self.clone();
145        Query::new(async move {
146            let v = inner.get().await;
147            m(v).await
148        })
149    }
150}
151
152impl<T: 'static + Clone + Send + Sync> Query<T> {
153    pub async fn get_cloned(self) -> T {
154        let value = self.get().await;
155        value.clone()
156    }
157
158    pub fn map_cloned<M, F, R>(&self, m: M) -> Query<R>
159    where
160        M: 'static + Send + FnOnce(T) -> F,
161        F: 'static + Send + Future<Output = R>,
162        R: 'static + Send + Sync,
163    {
164        let inner = self.clone();
165        Query::new(async move {
166            let v = inner.get_cloned().await;
167            m(v).await
168        })
169    }
170}
171
172impl<T> core::future::IntoFuture for Query<T>
173where
174    T: 'static + Clone + Send + Sync,
175{
176    type Output = T;
177    type IntoFuture = Pin<Box<dyn 'static + Send + Future<Output = T>>>;
178
179    fn into_future(self) -> Self::IntoFuture {
180        Box::pin(self.get_cloned())
181    }
182}
183
184trait InnerState: Send + Sync {
185    type Output;
186
187    fn try_get(&self) -> Option<&Self::Output>;
188    unsafe fn get_unchecked(&self) -> &Self::Output;
189    fn initialized(&self) -> bool;
190    fn semaphore(&self) -> &Semaphore;
191    unsafe fn poll(&self, cx: &mut Context) -> Poll<()>;
192}
193
194struct Inner<T, F> {
195    value_set: AtomicBool,
196    value: UnsafeCell<MaybeUninit<T>>,
197    semaphore: Semaphore,
198    future: UnsafeCell<FutureState<F>>,
199}
200
201unsafe impl<T: Sync + Send, F: Send> Sync for Inner<T, F> {}
202unsafe impl<T: Sync + Send, F: Send> Send for Inner<T, F> {}
203
204impl<T, F> Inner<T, F> {
205    fn initialized(&self) -> bool {
206        // Using acquire ordering so any threads that read a true from this
207        // atomic is able to read the value.
208        self.value_set.load(Ordering::Acquire)
209    }
210}
211
212impl<T, F> Drop for Inner<T, F> {
213    fn drop(&mut self) {
214        if self.initialized() {
215            unsafe {
216                (*self.value.get()).assume_init_drop();
217            }
218        }
219    }
220}
221
222impl<T, F> InnerState for Inner<T, F>
223where
224    T: Send + Sync,
225    F: Send + Future<Output = T>,
226{
227    type Output = T;
228
229    unsafe fn get_unchecked(&self) -> &Self::Output {
230        debug_assert!(self.initialized());
231
232        (*self.value.get()).assume_init_ref()
233    }
234
235    fn try_get(&self) -> Option<&T> {
236        if self.initialized() {
237            // SAFETY: The Query has been fully initialized.
238            Some(unsafe { self.get_unchecked() })
239        } else {
240            None
241        }
242    }
243
244    fn initialized(&self) -> bool {
245        Inner::initialized(self)
246    }
247
248    fn semaphore(&self) -> &Semaphore {
249        &self.semaphore
250    }
251
252    unsafe fn poll(&self, cx: &mut Context) -> Poll<()> {
253        let future = &mut *self.future.get();
254        match future {
255            FutureState::Init(future) => {
256                // the Inner is allocated and is stable
257                let future = Pin::new_unchecked(future);
258                let value = ready!(future.poll(cx));
259
260                self.value.get().write(MaybeUninit::new(value));
261                self.future.get().write(FutureState::Finished);
262
263                // Using release ordering so any threads that read a true from this
264                // atomic is able to read the value we just stored.
265                self.value_set.store(true, Ordering::Release);
266                self.semaphore.close();
267
268                Poll::Ready(())
269            }
270            FutureState::Finished => {
271                debug_assert!(self.initialized());
272                Poll::Ready(())
273            }
274        }
275    }
276}
277
278struct Delegate<T, F> {
279    inner: Inner<Query<T>, F>,
280    query_fut: UnsafeCell<Option<Pin<Box<dyn Future<Output = ()>>>>>,
281}
282
283unsafe impl<T: Sync + Send, F: Send> Sync for Delegate<T, F> {}
284unsafe impl<T: Sync + Send, F: Send> Send for Delegate<T, F> {}
285
286impl<T, F> InnerState for Delegate<T, F>
287where
288    T: 'static + Send + Sync,
289    F: Send + Future<Output = Query<T>>,
290{
291    type Output = T;
292
293    unsafe fn get_unchecked(&self) -> &Self::Output {
294        self.inner.get_unchecked().inner.get_unchecked()
295    }
296
297    fn try_get(&self) -> Option<&T> {
298        if self.initialized() {
299            // SAFETY: The Query has been fully initialized.
300            let query = unsafe { self.inner.get_unchecked() };
301            if query.inner.initialized() {
302                Some(unsafe { query.inner.get_unchecked() })
303            } else {
304                None
305            }
306        } else {
307            None
308        }
309    }
310
311    fn initialized(&self) -> bool {
312        self.inner.initialized()
313    }
314
315    fn semaphore(&self) -> &Semaphore {
316        &self.inner.semaphore
317    }
318
319    unsafe fn poll(&self, cx: &mut Context) -> Poll<()> {
320        loop {
321            if let Some(ref mut f) = &mut *self.query_fut.get() {
322                ready!(f.poll_unpin(cx));
323
324                *self.query_fut.get() = None;
325                return Poll::Ready(());
326            }
327
328            ready!(self.inner.poll(cx));
329
330            let query = self.inner.get_unchecked().clone();
331            *self.query_fut.get() = Some(Box::pin(async move {
332                // evaluate the nested query
333                query.get().await;
334            }));
335        }
336    }
337}
338
339struct Spawn<T, F> {
340    value_set: AtomicBool,
341    value: UnsafeCell<MaybeUninit<T>>,
342    semaphore: Semaphore,
343    future: UnsafeCell<SpawnFutureState<T, F>>,
344}
345
346unsafe impl<T: Sync + Send, F: Send> Sync for Spawn<T, F> {}
347unsafe impl<T: Sync + Send, F: Send> Send for Spawn<T, F> {}
348
349impl<T, F> InnerState for Spawn<T, F>
350where
351    T: 'static + Send + Sync,
352    F: 'static + Send + Future<Output = T>,
353{
354    type Output = T;
355
356    unsafe fn get_unchecked(&self) -> &Self::Output {
357        debug_assert!(self.initialized());
358
359        (*self.value.get()).assume_init_ref()
360    }
361
362    fn try_get(&self) -> Option<&T> {
363        if self.initialized() {
364            // SAFETY: The Query has been fully initialized.
365            Some(unsafe { self.get_unchecked() })
366        } else {
367            None
368        }
369    }
370
371    fn initialized(&self) -> bool {
372        // Using acquire ordering so any threads that read a true from this
373        // atomic is able to read the value.
374        self.value_set.load(Ordering::Acquire)
375    }
376
377    fn semaphore(&self) -> &Semaphore {
378        &self.semaphore
379    }
380
381    unsafe fn poll(&self, cx: &mut Context) -> Poll<()> {
382        let future = &mut *self.future.get();
383        loop {
384            match core::mem::replace(future, SpawnFutureState::Finished) {
385                SpawnFutureState::Init(fut) => {
386                    let handle = tokio::spawn(fut);
387                    *future = SpawnFutureState::Spawned(handle);
388                }
389                SpawnFutureState::Spawned(mut handle) => {
390                    let value = match Pin::new(&mut handle).poll(cx) {
391                        Poll::Ready(value) => value,
392                        Poll::Pending => {
393                            *future = SpawnFutureState::Spawned(handle);
394                            return Poll::Pending;
395                        }
396                    };
397
398                    return match value {
399                        Ok(value) => {
400                            self.value.get().write(MaybeUninit::new(value));
401                            self.future.get().write(SpawnFutureState::Spawned(handle));
402
403                            // Using release ordering so any threads that read a true from this
404                            // atomic is able to read the value we just stored.
405                            self.value_set.store(true, Ordering::Release);
406                            self.semaphore.close();
407
408                            Poll::Ready(())
409                        }
410                        Err(err) => match err.try_into_panic() {
411                            Ok(reason) => std::panic::resume_unwind(reason),
412                            Err(err) => panic!("{}", err),
413                        },
414                    };
415                }
416                SpawnFutureState::Finished => {
417                    debug_assert!(self.initialized());
418                    return Poll::Ready(());
419                }
420            }
421        }
422    }
423}
424
/// Lifecycle of the future owned by an `Inner` query state.
enum FutureState<F> {
    /// The computation has not completed yet; it is polled in place.
    Init(F),
    /// The output has been stored, or there never was a future
    /// (`Query::from`).
    Finished,
}
429
430enum SpawnFutureState<T, F> {
431    Init(F),
432    Spawned(tokio::task::JoinHandle<T>),
433    Finished,
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    /// Two concurrent getters share one computation and both see its value.
    #[tokio::test]
    async fn query_test() {
        let (tx, rx) = oneshot::channel::<u64>();

        let query = Query::new(async move { rx.await.unwrap() });

        let a = query.clone();
        let a = async move { *a.get().await };

        let b = query;
        let b = async move { *b.get().await };

        tx.send(123).unwrap();

        let (a, b) = tokio::join!(a, b);

        assert_eq!(a, 123);
        assert_eq!(b, 123);
    }

    /// A query built from a value is resolved without ever being polled.
    #[tokio::test]
    async fn from_value_is_immediately_available() {
        let query = Query::from(42u64);
        assert_eq!(query.try_get(), Some(&42));
        assert_eq!(*query.get().await, 42);
    }

    /// `try_get` and `Debug` never wait on an unresolved query.
    #[tokio::test]
    async fn unresolved_query_is_not_available() {
        let query = Query::new(std::future::pending::<u64>());
        assert_eq!(query.try_get(), None);
        assert_eq!(format!("{:?}", query), "Query(None)");
        assert_eq!(format!("{:?}", Query::from(1u64)), "Query(Some(1))");
    }

    /// `map` and `map_cloned` derive new queries from a resolved one.
    #[tokio::test]
    async fn map_and_map_cloned() {
        let base = Query::from(2u64);
        let doubled = base.map(|v: &u64| {
            let v = *v;
            async move { v * 2 }
        });
        let tripled = base.map_cloned(|v| async move { v * 3 });
        assert_eq!(*doubled.get().await, 4);
        assert_eq!(*tripled.get().await, 6);
    }

    /// A spawned query runs on the runtime and stores its result.
    #[tokio::test]
    async fn spawn_resolves_on_runtime() {
        let query = Query::spawn(async { 7u64 });
        assert_eq!(*query.get().await, 7);
        assert_eq!(query.try_get(), Some(&7));
    }

    /// A delegating query resolves to the value of the query it yields.
    #[tokio::test]
    async fn delegate_forwards_nested_value() {
        let (tx, rx) = oneshot::channel::<u64>();
        let nested = Query::new(async move { rx.await.unwrap() });
        let query = Query::delegate(async move { nested });
        tx.send(5).unwrap();
        assert_eq!(*query.get().await, 5);
        assert_eq!(query.try_get(), Some(&5));
    }

    /// `.await` on a query (via `IntoFuture`) and `get_cloned` return owned values.
    #[tokio::test]
    async fn into_future_and_get_cloned() {
        let value: u64 = Query::from(3u64).await;
        assert_eq!(value, 3);
        let text = Query::from(String::from("x")).get_cloned().await;
        assert_eq!(text, "x");
    }
}