1use futures::{future::poll_fn, ready, FutureExt};
5use std::{
6 cell::UnsafeCell,
7 fmt,
8 future::{Future, Pending},
9 mem::MaybeUninit,
10 pin::Pin,
11 sync::{
12 atomic::{AtomicBool, Ordering},
13 Arc,
14 },
15 task::{Context, Poll},
16};
17use tokio::sync::Semaphore;
18
19pub struct Query<T> {
20 inner: Arc<dyn InnerState<Output = T>>,
21}
22
23impl<T> Clone for Query<T> {
24 fn clone(&self) -> Self {
25 Self {
26 inner: self.inner.clone(),
27 }
28 }
29}
30
31impl<T: 'static + Send + Sync + fmt::Debug> fmt::Debug for Query<T> {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 f.debug_tuple("Query").field(&self.try_get()).finish()
34 }
35}
36
37unsafe impl<T: Sync + Send> Sync for Query<T> {}
38unsafe impl<T: Sync + Send> Send for Query<T> {}
39
40impl<T: 'static + Send + Sync> From<T> for Query<T> {
41 fn from(value: T) -> Self {
42 let semaphore = Semaphore::new(0);
43 semaphore.close();
44
45 let future = UnsafeCell::new(FutureState::<Pending<T>>::Finished);
46
47 let inner = Inner {
48 value_set: AtomicBool::new(true),
49 value: UnsafeCell::new(MaybeUninit::new(value)),
50 semaphore,
51 future,
52 };
53
54 Query {
55 inner: Arc::new(inner),
56 }
57 }
58}
59
60impl<T: 'static + Send + Sync> Query<T> {
61 pub fn new<F: 'static + Future<Output = T> + Send>(future: F) -> Self {
62 let inner = Inner {
63 value_set: AtomicBool::new(false),
64 value: UnsafeCell::new(MaybeUninit::uninit()),
65 semaphore: Semaphore::new(1),
66 future: UnsafeCell::new(FutureState::Init(future)),
67 };
68
69 Self {
70 inner: Arc::new(inner),
71 }
72 }
73
74 pub fn delegate<F: 'static + Future<Output = Query<T>> + Send>(future: F) -> Self {
75 let inner = Inner {
76 value_set: AtomicBool::new(false),
77 value: UnsafeCell::new(MaybeUninit::uninit()),
78 semaphore: Semaphore::new(1),
79 future: UnsafeCell::new(FutureState::Init(future)),
80 };
81
82 let inner = Delegate {
83 inner,
84 query_fut: UnsafeCell::new(None),
85 };
86
87 Self {
88 inner: Arc::new(inner),
89 }
90 }
91
92 pub fn spawn<F: 'static + Future<Output = T> + Send>(future: F) -> Self {
93 let inner = Spawn {
94 value_set: AtomicBool::new(false),
95 value: UnsafeCell::new(MaybeUninit::uninit()),
96 semaphore: Semaphore::new(1),
97 future: UnsafeCell::new(SpawnFutureState::Init(future)),
98 };
99
100 Self {
101 inner: Arc::new(inner),
102 }
103 }
104
105 pub async fn get(&self) -> &T {
106 if let Some(value) = self.try_get() {
107 return value;
108 }
109
110 if let Ok(permit) = self.inner.semaphore().acquire().await {
115 debug_assert!(!self.inner.initialized());
116
117 poll_fn(move |cx| unsafe {
121 self.inner.poll(cx)
123 })
124 .await;
125
126 permit.forget();
127 }
128
129 unsafe { self.inner.get_unchecked() }
132 }
133
134 pub fn try_get(&self) -> Option<&T> {
135 self.inner.try_get()
136 }
137
138 pub fn map<M, F, R>(&self, m: M) -> Query<R>
139 where
140 M: 'static + Send + FnOnce(&T) -> F,
141 F: 'static + Send + Future<Output = R>,
142 R: 'static + Send + Sync,
143 {
144 let inner = self.clone();
145 Query::new(async move {
146 let v = inner.get().await;
147 m(v).await
148 })
149 }
150}
151
152impl<T: 'static + Clone + Send + Sync> Query<T> {
153 pub async fn get_cloned(self) -> T {
154 let value = self.get().await;
155 value.clone()
156 }
157
158 pub fn map_cloned<M, F, R>(&self, m: M) -> Query<R>
159 where
160 M: 'static + Send + FnOnce(T) -> F,
161 F: 'static + Send + Future<Output = R>,
162 R: 'static + Send + Sync,
163 {
164 let inner = self.clone();
165 Query::new(async move {
166 let v = inner.get_cloned().await;
167 m(v).await
168 })
169 }
170}
171
172impl<T> core::future::IntoFuture for Query<T>
173where
174 T: 'static + Clone + Send + Sync,
175{
176 type Output = T;
177 type IntoFuture = Pin<Box<dyn 'static + Send + Future<Output = T>>>;
178
179 fn into_future(self) -> Self::IntoFuture {
180 Box::pin(self.get_cloned())
181 }
182}
183
184trait InnerState: Send + Sync {
185 type Output;
186
187 fn try_get(&self) -> Option<&Self::Output>;
188 unsafe fn get_unchecked(&self) -> &Self::Output;
189 fn initialized(&self) -> bool;
190 fn semaphore(&self) -> &Semaphore;
191 unsafe fn poll(&self, cx: &mut Context) -> Poll<()>;
192}
193
194struct Inner<T, F> {
195 value_set: AtomicBool,
196 value: UnsafeCell<MaybeUninit<T>>,
197 semaphore: Semaphore,
198 future: UnsafeCell<FutureState<F>>,
199}
200
201unsafe impl<T: Sync + Send, F: Send> Sync for Inner<T, F> {}
202unsafe impl<T: Sync + Send, F: Send> Send for Inner<T, F> {}
203
204impl<T, F> Inner<T, F> {
205 fn initialized(&self) -> bool {
206 self.value_set.load(Ordering::Acquire)
209 }
210}
211
212impl<T, F> Drop for Inner<T, F> {
213 fn drop(&mut self) {
214 if self.initialized() {
215 unsafe {
216 (*self.value.get()).assume_init_drop();
217 }
218 }
219 }
220}
221
222impl<T, F> InnerState for Inner<T, F>
223where
224 T: Send + Sync,
225 F: Send + Future<Output = T>,
226{
227 type Output = T;
228
229 unsafe fn get_unchecked(&self) -> &Self::Output {
230 debug_assert!(self.initialized());
231
232 (*self.value.get()).assume_init_ref()
233 }
234
235 fn try_get(&self) -> Option<&T> {
236 if self.initialized() {
237 Some(unsafe { self.get_unchecked() })
239 } else {
240 None
241 }
242 }
243
244 fn initialized(&self) -> bool {
245 Inner::initialized(self)
246 }
247
248 fn semaphore(&self) -> &Semaphore {
249 &self.semaphore
250 }
251
252 unsafe fn poll(&self, cx: &mut Context) -> Poll<()> {
253 let future = &mut *self.future.get();
254 match future {
255 FutureState::Init(future) => {
256 let future = Pin::new_unchecked(future);
258 let value = ready!(future.poll(cx));
259
260 self.value.get().write(MaybeUninit::new(value));
261 self.future.get().write(FutureState::Finished);
262
263 self.value_set.store(true, Ordering::Release);
266 self.semaphore.close();
267
268 Poll::Ready(())
269 }
270 FutureState::Finished => {
271 debug_assert!(self.initialized());
272 Poll::Ready(())
273 }
274 }
275 }
276}
277
278struct Delegate<T, F> {
279 inner: Inner<Query<T>, F>,
280 query_fut: UnsafeCell<Option<Pin<Box<dyn Future<Output = ()>>>>>,
281}
282
283unsafe impl<T: Sync + Send, F: Send> Sync for Delegate<T, F> {}
284unsafe impl<T: Sync + Send, F: Send> Send for Delegate<T, F> {}
285
286impl<T, F> InnerState for Delegate<T, F>
287where
288 T: 'static + Send + Sync,
289 F: Send + Future<Output = Query<T>>,
290{
291 type Output = T;
292
293 unsafe fn get_unchecked(&self) -> &Self::Output {
294 self.inner.get_unchecked().inner.get_unchecked()
295 }
296
297 fn try_get(&self) -> Option<&T> {
298 if self.initialized() {
299 let query = unsafe { self.inner.get_unchecked() };
301 if query.inner.initialized() {
302 Some(unsafe { query.inner.get_unchecked() })
303 } else {
304 None
305 }
306 } else {
307 None
308 }
309 }
310
311 fn initialized(&self) -> bool {
312 self.inner.initialized()
313 }
314
315 fn semaphore(&self) -> &Semaphore {
316 &self.inner.semaphore
317 }
318
319 unsafe fn poll(&self, cx: &mut Context) -> Poll<()> {
320 loop {
321 if let Some(ref mut f) = &mut *self.query_fut.get() {
322 ready!(f.poll_unpin(cx));
323
324 *self.query_fut.get() = None;
325 return Poll::Ready(());
326 }
327
328 ready!(self.inner.poll(cx));
329
330 let query = self.inner.get_unchecked().clone();
331 *self.query_fut.get() = Some(Box::pin(async move {
332 query.get().await;
334 }));
335 }
336 }
337}
338
339struct Spawn<T, F> {
340 value_set: AtomicBool,
341 value: UnsafeCell<MaybeUninit<T>>,
342 semaphore: Semaphore,
343 future: UnsafeCell<SpawnFutureState<T, F>>,
344}
345
346unsafe impl<T: Sync + Send, F: Send> Sync for Spawn<T, F> {}
347unsafe impl<T: Sync + Send, F: Send> Send for Spawn<T, F> {}
348
349impl<T, F> InnerState for Spawn<T, F>
350where
351 T: 'static + Send + Sync,
352 F: 'static + Send + Future<Output = T>,
353{
354 type Output = T;
355
356 unsafe fn get_unchecked(&self) -> &Self::Output {
357 debug_assert!(self.initialized());
358
359 (*self.value.get()).assume_init_ref()
360 }
361
362 fn try_get(&self) -> Option<&T> {
363 if self.initialized() {
364 Some(unsafe { self.get_unchecked() })
366 } else {
367 None
368 }
369 }
370
371 fn initialized(&self) -> bool {
372 self.value_set.load(Ordering::Acquire)
375 }
376
377 fn semaphore(&self) -> &Semaphore {
378 &self.semaphore
379 }
380
381 unsafe fn poll(&self, cx: &mut Context) -> Poll<()> {
382 let future = &mut *self.future.get();
383 loop {
384 match core::mem::replace(future, SpawnFutureState::Finished) {
385 SpawnFutureState::Init(fut) => {
386 let handle = tokio::spawn(fut);
387 *future = SpawnFutureState::Spawned(handle);
388 }
389 SpawnFutureState::Spawned(mut handle) => {
390 let value = match Pin::new(&mut handle).poll(cx) {
391 Poll::Ready(value) => value,
392 Poll::Pending => {
393 *future = SpawnFutureState::Spawned(handle);
394 return Poll::Pending;
395 }
396 };
397
398 return match value {
399 Ok(value) => {
400 self.value.get().write(MaybeUninit::new(value));
401 self.future.get().write(SpawnFutureState::Spawned(handle));
402
403 self.value_set.store(true, Ordering::Release);
406 self.semaphore.close();
407
408 Poll::Ready(())
409 }
410 Err(err) => match err.try_into_panic() {
411 Ok(reason) => std::panic::resume_unwind(reason),
412 Err(err) => panic!("{}", err),
413 },
414 };
415 }
416 SpawnFutureState::Finished => {
417 debug_assert!(self.initialized());
418 return Poll::Ready(());
419 }
420 }
421 }
422 }
423}
424
/// Progress of an in-place query future.
enum FutureState<F> {
    /// The future has not produced its output yet.
    Init(F),
    /// The output has been moved into the query's value slot.
    Finished,
}
429
430enum SpawnFutureState<T, F> {
431 Init(F),
432 Spawned(tokio::task::JoinHandle<T>),
433 Finished,
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    /// Two concurrent `get` calls on clones of one query both observe the
    /// value sent through the channel.
    #[tokio::test]
    async fn query_test() {
        let (sender, receiver) = oneshot::channel::<u64>();
        let query = Query::new(async move { receiver.await.unwrap() });

        let first = {
            let handle = query.clone();
            async move { *handle.get().await }
        };
        let second = async move { *query.get().await };

        sender.send(123).unwrap();

        let (a, b) = tokio::join!(first, second);

        assert_eq!(a, 123);
        assert_eq!(b, 123);
    }
}