Skip to main content

deno_unsync/
future.rs

1// Copyright 2018-2024 the Deno authors. MIT license.
2
3use parking_lot::Mutex;
4use std::cell::RefCell;
5use std::future::Future;
6use std::pin::Pin;
7use std::rc::Rc;
8use std::sync::Arc;
9use std::task::Context;
10use std::task::Wake;
11use std::task::Waker;
12
13use crate::sync::AtomicFlag;
14
15impl<T: ?Sized> LocalFutureExt for T where T: Future {}
16
17pub trait LocalFutureExt: std::future::Future {
18  fn shared_local(self) -> SharedLocal<Self>
19  where
20    Self: Sized,
21    Self::Output: Clone,
22  {
23    SharedLocal::new(self)
24  }
25}
26
/// What a shared future currently holds: the inner future until it has
/// completed, and afterwards the value it produced.
enum FutureOrOutput<TFuture>
where
  TFuture: Future,
{
  /// The inner future has not produced a value yet.
  Future(TFuture),
  /// The inner future completed with this value.
  Output(TFuture::Output),
}
31
32impl<TFuture: Future> std::fmt::Debug for FutureOrOutput<TFuture>
33where
34  TFuture::Output: std::fmt::Debug,
35{
36  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37    match self {
38      Self::Future(_) => f.debug_tuple("Future").field(&"<pending>").finish(),
39      Self::Output(arg0) => f.debug_tuple("Result").field(arg0).finish(),
40    }
41  }
42}
43
44struct SharedLocalData<TFuture: Future> {
45  future_or_output: FutureOrOutput<TFuture>,
46}
47
48impl<TFuture: Future> std::fmt::Debug for SharedLocalData<TFuture>
49where
50  TFuture::Output: std::fmt::Debug,
51{
52  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53    f.debug_struct("SharedLocalData")
54      .field("future_or_output", &self.future_or_output)
55      .finish()
56  }
57}
58
59struct SharedLocalInner<TFuture: Future> {
60  data: RefCell<SharedLocalData<TFuture>>,
61  child_waker_state: Arc<ChildWakerState>,
62}
63
64impl<TFuture: Future> std::fmt::Debug for SharedLocalInner<TFuture>
65where
66  TFuture::Output: std::fmt::Debug,
67{
68  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69    f.debug_struct("SharedLocalInner")
70      .field("data", &self.data)
71      .field("child_waker_state", &self.child_waker_state)
72      .finish()
73  }
74}
75
76/// A !Send-friendly future whose result can be awaited multiple times.
77#[must_use = "futures do nothing unless you `.await` or poll them"]
78pub struct SharedLocal<TFuture: Future>(Rc<SharedLocalInner<TFuture>>);
79
80impl<TFuture: Future> Clone for SharedLocal<TFuture>
81where
82  TFuture::Output: Clone,
83{
84  fn clone(&self) -> Self {
85    Self(self.0.clone())
86  }
87}
88
89impl<TFuture: Future> std::fmt::Debug for SharedLocal<TFuture>
90where
91  TFuture::Output: std::fmt::Debug,
92{
93  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94    f.debug_tuple("SharedLocal").field(&self.0).finish()
95  }
96}
97
98impl<TFuture: Future> SharedLocal<TFuture>
99where
100  TFuture::Output: Clone,
101{
102  pub fn new(future: TFuture) -> Self {
103    SharedLocal(Rc::new(SharedLocalInner {
104      data: RefCell::new(SharedLocalData {
105        future_or_output: FutureOrOutput::Future(future),
106      }),
107      child_waker_state: Arc::new(ChildWakerState {
108        can_poll: AtomicFlag::raised(),
109        wakers: Default::default(),
110      }),
111    }))
112  }
113}
114
115impl<TFuture: Future> std::future::Future for SharedLocal<TFuture>
116where
117  TFuture::Output: Clone,
118{
119  type Output = TFuture::Output;
120
121  fn poll(
122    self: std::pin::Pin<&mut Self>,
123    cx: &mut std::task::Context<'_>,
124  ) -> std::task::Poll<Self::Output> {
125    use std::task::Poll;
126
127    let mut inner = self.0.data.borrow_mut();
128    match &mut inner.future_or_output {
129      FutureOrOutput::Future(fut) => {
130        self.0.child_waker_state.wakers.push(cx.waker().clone());
131        if self.0.child_waker_state.can_poll.lower() {
132          let child_waker = Waker::from(self.0.child_waker_state.clone());
133          let mut child_cx = Context::from_waker(&child_waker);
134          let fut = unsafe { Pin::new_unchecked(fut) };
135          match fut.poll(&mut child_cx) {
136            Poll::Ready(result) => {
137              inner.future_or_output = FutureOrOutput::Output(result.clone());
138              drop(inner); // stop borrow_mut
139              let wakers = self.0.child_waker_state.wakers.take_all();
140              for waker in wakers {
141                waker.wake();
142              }
143              Poll::Ready(result)
144            }
145            Poll::Pending => Poll::Pending,
146          }
147        } else {
148          Poll::Pending
149        }
150      }
151      FutureOrOutput::Output(result) => Poll::Ready(result.clone()),
152    }
153  }
154}
155
156#[derive(Debug, Default)]
157struct WakerStore(Mutex<Vec<Waker>>);
158
159impl WakerStore {
160  pub fn take_all(&self) -> Vec<Waker> {
161    let mut wakers = self.0.lock();
162    std::mem::take(&mut *wakers)
163  }
164
165  pub fn clone_all(&self) -> Vec<Waker> {
166    self.0.lock().clone()
167  }
168
169  pub fn push(&self, waker: Waker) {
170    self.0.lock().push(waker);
171  }
172}
173
174#[derive(Debug)]
175struct ChildWakerState {
176  can_poll: AtomicFlag,
177  wakers: WakerStore,
178}
179
180impl Wake for ChildWakerState {
181  fn wake(self: Arc<Self>) {
182    self.can_poll.raise();
183    let wakers = self.wakers.take_all();
184
185    for waker in wakers {
186      waker.wake();
187    }
188  }
189
190  fn wake_by_ref(self: &Arc<Self>) {
191    self.can_poll.raise();
192    let wakers = self.wakers.clone_all();
193
194    for waker in wakers {
195      waker.wake_by_ref();
196    }
197  }
198}
199
#[cfg(test)]
mod test {
  use std::sync::Arc;

  use tokio::sync::Notify;

  use super::LocalFutureExt;

  /// A clone and the original both resolve to the output.
  #[tokio::test(flavor = "current_thread")]
  async fn test_shared_local_future() {
    let shared = super::SharedLocal::new(Box::pin(async { 42 }));
    let from_clone = shared.clone().await;
    assert_eq!(from_clone, 42);
    let from_original = shared.await;
    assert_eq!(from_original, 42);
  }

  /// Same as above, but built through the `shared_local` extension method.
  #[tokio::test(flavor = "current_thread")]
  async fn test_shared_local() {
    let shared = async { 42 }.shared_local();
    let from_clone = shared.clone().await;
    assert_eq!(from_clone, 42);
    let from_original = shared.await;
    assert_eq!(from_original, 42);
  }

  /// Ten spawned tasks await one shared future that only finishes after it
  /// is notified; every task must complete.
  #[tokio::test(flavor = "current_thread")]
  async fn multiple_tasks_waiting() {
    let notify = Arc::new(Notify::new());
    let notify_for_future = Arc::clone(&notify);
    let shared = async move {
      tokio::task::yield_now().await;
      notify_for_future.notified().await;
      tokio::task::yield_now().await;
      tokio::task::yield_now().await;
    }
    .shared_local();

    let tasks: Vec<_> =
      (0..10).map(|_| crate::spawn(shared.clone())).collect();

    let driver = crate::spawn(async move {
      notify.notify_one();
      for task in tasks {
        task.await.unwrap();
      }
    });
    driver.await.unwrap()
  }
}