1use parking_lot::Mutex;
4use std::cell::RefCell;
5use std::future::Future;
6use std::pin::Pin;
7use std::rc::Rc;
8use std::sync::Arc;
9use std::task::Context;
10use std::task::Wake;
11use std::task::Waker;
12
13use crate::sync::AtomicFlag;
14
15impl<T: ?Sized> LocalFutureExt for T where T: Future {}
16
17pub trait LocalFutureExt: std::future::Future {
18 fn shared_local(self) -> SharedLocal<Self>
19 where
20 Self: Sized,
21 Self::Output: Clone,
22 {
23 SharedLocal::new(self)
24 }
25}
26
/// Either the still-running inner future or the output it produced.
enum FutureOrOutput<TFuture>
where
  TFuture: Future,
{
  /// The inner future has not completed yet.
  Future(TFuture),
  /// The inner future completed with this output.
  Output(TFuture::Output),
}

impl<TFuture> std::fmt::Debug for FutureOrOutput<TFuture>
where
  TFuture: Future,
  TFuture::Output: std::fmt::Debug,
{
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    // The future itself is not required to be `Debug`, so a placeholder is
    // printed for it; a finished value is printed under the label "Result".
    let (label, value): (&str, &dyn std::fmt::Debug) = match self {
      Self::Future(_) => ("Future", &"<pending>"),
      Self::Output(output) => ("Result", output),
    };
    f.debug_tuple(label).field(value).finish()
  }
}
43
44struct SharedLocalData<TFuture: Future> {
45 future_or_output: FutureOrOutput<TFuture>,
46}
47
48impl<TFuture: Future> std::fmt::Debug for SharedLocalData<TFuture>
49where
50 TFuture::Output: std::fmt::Debug,
51{
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("SharedLocalData")
54 .field("future_or_output", &self.future_or_output)
55 .finish()
56 }
57}
58
59struct SharedLocalInner<TFuture: Future> {
60 data: RefCell<SharedLocalData<TFuture>>,
61 child_waker_state: Arc<ChildWakerState>,
62}
63
64impl<TFuture: Future> std::fmt::Debug for SharedLocalInner<TFuture>
65where
66 TFuture::Output: std::fmt::Debug,
67{
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct("SharedLocalInner")
70 .field("data", &self.data)
71 .field("child_waker_state", &self.child_waker_state)
72 .finish()
73 }
74}
75
76#[must_use = "futures do nothing unless you `.await` or poll them"]
78pub struct SharedLocal<TFuture: Future>(Rc<SharedLocalInner<TFuture>>);
79
80impl<TFuture: Future> Clone for SharedLocal<TFuture>
81where
82 TFuture::Output: Clone,
83{
84 fn clone(&self) -> Self {
85 Self(self.0.clone())
86 }
87}
88
89impl<TFuture: Future> std::fmt::Debug for SharedLocal<TFuture>
90where
91 TFuture::Output: std::fmt::Debug,
92{
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_tuple("SharedLocal").field(&self.0).finish()
95 }
96}
97
98impl<TFuture: Future> SharedLocal<TFuture>
99where
100 TFuture::Output: Clone,
101{
102 pub fn new(future: TFuture) -> Self {
103 SharedLocal(Rc::new(SharedLocalInner {
104 data: RefCell::new(SharedLocalData {
105 future_or_output: FutureOrOutput::Future(future),
106 }),
107 child_waker_state: Arc::new(ChildWakerState {
108 can_poll: AtomicFlag::raised(),
109 wakers: Default::default(),
110 }),
111 }))
112 }
113}
114
impl<TFuture: Future> std::future::Future for SharedLocal<TFuture>
where
  TFuture::Output: Clone,
{
  type Output = TFuture::Output;

  /// Polls the shared inner future on behalf of this clone.
  ///
  /// Once the inner future has completed, every poll returns a clone of the
  /// stored output. Until then, the calling task's waker is registered, and the
  /// inner future is polled only when `can_poll.lower()` succeeds. Other clones
  /// return `Pending` and are woken together when the inner future's waker
  /// fires. NOTE(review): this assumes `AtomicFlag::lower` returns whether the
  /// flag was previously raised; confirm against `crate::sync`.
  ///
  /// # Panics
  ///
  /// Panics (via `RefCell::borrow_mut`) if this `SharedLocal`'s state is
  /// polled re-entrantly from inside the inner future's own poll, because the
  /// borrow is held for the duration of that poll.
  fn poll(
    self: std::pin::Pin<&mut Self>,
    cx: &mut std::task::Context<'_>,
  ) -> std::task::Poll<Self::Output> {
    use std::task::Poll;

    // Held while the inner future is polled; on completion it is released
    // explicitly before other tasks are woken.
    let mut inner = self.0.data.borrow_mut();
    match &mut inner.future_or_output {
      FutureOrOutput::Future(fut) => {
        // Register before checking `can_poll`, so a wake-up that happens after
        // the check still reaches this task.
        self.0.child_waker_state.wakers.push(cx.waker().clone());
        if self.0.child_waker_state.can_poll.lower() {
          // The inner future gets its own waker. Waking it raises `can_poll`
          // and forwards the wake-up to every registered task.
          let child_waker = Waker::from(self.0.child_waker_state.clone());
          let mut child_cx = Context::from_waker(&child_waker);
          // SAFETY: the future lives inside the `Rc` allocation shared by all
          // clones and is never moved out by any code in this file. It stays in
          // place until the assignment below overwrites the `Future` variant,
          // which drops it in place first.
          let fut = unsafe { Pin::new_unchecked(fut) };
          match fut.poll(&mut child_cx) {
            Poll::Ready(result) => {
              // Store a copy for the other clones; the original is returned.
              inner.future_or_output = FutureOrOutput::Output(result.clone());
              // End the `RefCell` borrow before running any wakers.
              drop(inner);
              // This also wakes the waker this task registered above.
              let wakers = self.0.child_waker_state.wakers.take_all();
              for waker in wakers {
                waker.wake();
              }
              Poll::Ready(result)
            }
            Poll::Pending => Poll::Pending,
          }
        } else {
          // The inner future was already polled since its waker last fired.
          // This task stays registered and is woken along with the others.
          Poll::Pending
        }
      }
      FutureOrOutput::Output(result) => Poll::Ready(result.clone()),
    }
  }
}
155
156#[derive(Debug, Default)]
157struct WakerStore(Mutex<Vec<Waker>>);
158
159impl WakerStore {
160 pub fn take_all(&self) -> Vec<Waker> {
161 let mut wakers = self.0.lock();
162 std::mem::take(&mut *wakers)
163 }
164
165 pub fn clone_all(&self) -> Vec<Waker> {
166 self.0.lock().clone()
167 }
168
169 pub fn push(&self, waker: Waker) {
170 self.0.lock().push(waker);
171 }
172}
173
174#[derive(Debug)]
175struct ChildWakerState {
176 can_poll: AtomicFlag,
177 wakers: WakerStore,
178}
179
180impl Wake for ChildWakerState {
181 fn wake(self: Arc<Self>) {
182 self.can_poll.raise();
183 let wakers = self.wakers.take_all();
184
185 for waker in wakers {
186 waker.wake();
187 }
188 }
189
190 fn wake_by_ref(self: &Arc<Self>) {
191 self.can_poll.raise();
192 let wakers = self.wakers.clone_all();
193
194 for waker in wakers {
195 waker.wake_by_ref();
196 }
197 }
198}
199
#[cfg(test)]
mod test {
  use std::sync::Arc;

  use tokio::sync::Notify;

  use super::LocalFutureExt;

  /// Awaiting a clone and then the original both yield the output.
  #[tokio::test(flavor = "current_thread")]
  async fn test_shared_local_future() {
    let original = super::SharedLocal::new(Box::pin(async { 42 }));
    let copy = original.clone();
    assert_eq!(copy.await, 42);
    assert_eq!(original.await, 42);
  }

  /// Same as above, but built through the `shared_local` extension method.
  #[tokio::test(flavor = "current_thread")]
  async fn test_shared_local() {
    let original = async { 42 }.shared_local();
    let copy = original.clone();
    assert_eq!(copy.await, 42);
    assert_eq!(original.await, 42);
  }

  /// Ten tasks awaiting the same shared future must all be woken when it
  /// completes.
  #[tokio::test(flavor = "current_thread")]
  async fn multiple_tasks_waiting() {
    let notify = Arc::new(Notify::new());
    let task_notify = Arc::clone(&notify);

    let shared = async move {
      tokio::task::yield_now().await;
      task_notify.notified().await;
      tokio::task::yield_now().await;
      tokio::task::yield_now().await;
    }
    .shared_local();
    let tasks = (0..10)
      .map(|_| crate::spawn(shared.clone()))
      .collect::<Vec<_>>();

    crate::spawn(async move {
      notify.notify_one();
      for task in tasks {
        task.await.unwrap();
      }
    })
    .await
    .unwrap()
  }
}
250}