1use std::cell::UnsafeCell;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::{Arc, Weak};
6use std::task::{Context, Poll};
7
8use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError};
9
10#[must_use = "futures do nothing unless you `.await` or poll them"]
11pub struct Shared<Fut: Future> {
12 inner: Option<Arc<Inner<Fut>>>,
13 permit_fut: Option<SyncBoxFuture<Result<OwnedSemaphorePermit, AcquireError>>>,
14 permit: Option<OwnedSemaphorePermit>,
15}
16
/// A heap-allocated, pinned, type-erased future that is both `Send` and `Sync`.
///
/// Used to store the semaphore's "acquire" future inside the handles without
/// naming its concrete type.
type SyncBoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'static>>;
18
19impl<Fut: Future> Clone for Shared<Fut> {
20 fn clone(&self) -> Self {
21 Self {
22 inner: self.inner.clone(),
23 permit_fut: None,
24 permit: None,
25 }
26 }
27}
28
29impl<Fut: Future> Shared<Fut> {
30 pub fn new(future: Fut) -> Self {
31 let semaphore = Arc::new(Semaphore::new(1));
32 let inner = Arc::new(Inner {
33 state: AtomicUsize::new(POLLING),
34 future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)),
35 semaphore,
36 });
37
38 Self {
39 inner: Some(inner),
40 permit_fut: None,
41 permit: None,
42 }
43 }
44
45 pub fn weak_future(&self) -> Option<WeakShared<Fut>> {
46 self.inner.as_ref().map(|inner| WeakShared {
47 inner: Some(Arc::downgrade(inner)),
48 permit_fut: None,
49 permit: None,
50 })
51 }
52
53 pub fn downgrade(&self) -> Option<WeakSharedHandle<Fut>> {
54 self.inner
55 .as_ref()
56 .map(|inner| WeakSharedHandle(Arc::downgrade(inner)))
57 }
58
59 pub fn consume(mut self) -> bool {
61 self.inner
62 .take()
63 .map(|inner| Arc::into_inner(inner).is_some())
64 .unwrap_or_default()
65 }
66}
67
68fn poll_impl<'cx, Fut>(
69 this_inner: &mut Option<Arc<Inner<Fut>>>,
70 this_permit_fut: &mut Option<SyncBoxFuture<Result<OwnedSemaphorePermit, AcquireError>>>,
71 this_permit: &mut Option<OwnedSemaphorePermit>,
72 cx: &mut Context<'cx>,
73) -> Poll<(Fut::Output, bool)>
74where
75 Fut: Future,
76 Fut::Output: Clone,
77{
78 let inner = this_inner
79 .take()
80 .expect("Shared future polled again after completion");
81
82 if inner.state.load(Ordering::Acquire) == COMPLETE {
84 return unsafe { Poll::Ready(inner.take_or_clone_output()) };
86 }
87
88 if this_permit.is_none() {
89 *this_permit = Some('permit: {
90 let permit_fut = if let Some(fut) = this_permit_fut.as_mut() {
92 fut
93 } else {
94 match Arc::clone(&inner.semaphore).try_acquire_owned() {
96 Ok(permit) => break 'permit permit,
97 Err(TryAcquireError::NoPermits) => {}
98 Err(TryAcquireError::Closed) => unreachable!(),
100 }
101
102 let next_fut = Arc::clone(&inner.semaphore).acquire_owned();
103 this_permit_fut.get_or_insert(Box::pin(next_fut))
104 };
105
106 match permit_fut.as_mut().poll(cx) {
108 Poll::Pending => {
109 *this_inner = Some(inner);
110 return Poll::Pending;
111 }
112 Poll::Ready(Ok(permit)) => {
113 *this_permit_fut = None;
115 permit
116 }
117 Poll::Ready(Err(_e)) => unreachable!(),
119 }
120 });
121 }
122
123 assert!(this_permit_fut.is_none(), "permit already acquired");
124
125 match inner.state.load(Ordering::Acquire) {
126 COMPLETE => {
127 return unsafe { Poll::Ready(inner.take_or_clone_output()) };
129 }
130 POISONED => panic!("inner future panicked during poll"),
131 _ => {}
132 }
133
134 struct Reset<'a> {
136 state: &'a AtomicUsize,
137 did_not_panic: bool,
138 }
139
140 impl Drop for Reset<'_> {
141 fn drop(&mut self) {
142 if !self.did_not_panic {
143 self.state.store(POISONED, Ordering::Release);
144 }
145 }
146 }
147
148 let mut reset = Reset {
149 state: &inner.state,
150 did_not_panic: false,
151 };
152
153 let output = {
154 let future = unsafe {
156 match &mut *inner.future_or_output.get() {
157 FutureOrOutput::Future(fut) => Pin::new_unchecked(fut),
158 FutureOrOutput::Output(_) => unreachable!(),
159 }
160 };
161
162 let poll_result = future.poll(cx);
163 reset.did_not_panic = true;
164
165 match poll_result {
166 Poll::Pending => {
167 drop(reset); *this_inner = Some(inner);
169 return Poll::Pending;
170 }
171 Poll::Ready(output) => output,
172 }
173 };
174
175 unsafe {
176 *inner.future_or_output.get() = FutureOrOutput::Output(output);
177 }
178
179 inner.state.store(COMPLETE, Ordering::Release);
180
181 drop(reset); unsafe { Poll::Ready(inner.take_or_clone_output()) }
187}
188
189impl<Fut> Future for Shared<Fut>
190where
191 Fut: Future,
192 Fut::Output: Clone,
193{
194 type Output = (Fut::Output, bool);
195
196 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
197 let Shared {
198 inner,
199 permit_fut,
200 permit,
201 } = &mut *self;
202
203 poll_impl(inner, permit_fut, permit, cx)
204 }
205}
206
207#[must_use = "futures do nothing unless you `.await` or poll them"]
210pub struct WeakShared<Fut: Future> {
211 inner: Option<Weak<Inner<Fut>>>,
212 permit_fut: Option<SyncBoxFuture<Result<OwnedSemaphorePermit, AcquireError>>>,
213 permit: Option<OwnedSemaphorePermit>,
214}
215
216impl<Fut> Future for WeakShared<Fut>
217where
218 Fut: Future,
219 Fut::Output: Clone,
220{
221 type Output = Option<(Fut::Output, bool)>;
222
223 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
224 let WeakShared {
225 inner,
226 permit_fut,
227 permit,
228 } = &mut *self;
229
230 let weak_inner = inner
231 .take()
232 .expect("Weak shared future polled again after completion");
233
234 let mut strong_inner = weak_inner.upgrade();
235
236 if strong_inner.is_none() {
237 return Poll::Ready(None);
238 };
239
240 let poll_result = poll_impl(&mut strong_inner, permit_fut, permit, cx);
241
242 *inner = strong_inner.is_some().then_some(weak_inner);
243
244 poll_result.map(Some)
245 }
246}
247
248#[repr(transparent)]
251pub struct WeakSharedHandle<Fut: Future>(Weak<Inner<Fut>>);
252
253impl<Fut: Future> WeakSharedHandle<Fut> {
254 pub fn upgrade(&self) -> Option<Shared<Fut>> {
255 self.0.upgrade().map(|inner| Shared {
256 inner: Some(inner),
257 permit_fut: None,
258 permit: None,
259 })
260 }
261
262 pub fn strong_count(&self) -> usize {
263 self.0.strong_count()
264 }
265}
266
267struct Inner<Fut: Future> {
268 state: AtomicUsize,
269 future_or_output: UnsafeCell<FutureOrOutput<Fut>>,
270 semaphore: Arc<Semaphore>,
271}
272
273impl<Fut> Inner<Fut>
274where
275 Fut: Future,
276 Fut::Output: Clone,
277{
278 unsafe fn take_or_clone_output(self: Arc<Self>) -> (Fut::Output, bool) {
281 match Arc::try_unwrap(self) {
282 Ok(inner) => match inner.future_or_output.into_inner() {
283 FutureOrOutput::Output(item) => (item, true),
284 FutureOrOutput::Future(_) => unreachable!(),
285 },
286 Err(inner) => match unsafe { &*inner.future_or_output.get() } {
287 FutureOrOutput::Output(item) => (item.clone(), false),
288 FutureOrOutput::Future(_) => unreachable!(),
289 },
290 }
291 }
292}
293
294unsafe impl<Fut> Send for Inner<Fut>
295where
296 Fut: Future + Send,
297 Fut::Output: Send + Sync,
298{
299}
300
301unsafe impl<Fut> Sync for Inner<Fut>
302where
303 Fut: Future + Send,
304 Fut::Output: Send + Sync,
305{
306}
307
/// Contents of the shared cell: the future while it is still running, or the
/// value it produced.
enum FutureOrOutput<Fut>
where
    Fut: Future,
{
    /// The future has not produced its output yet.
    Future(Fut),
    /// The future has finished; this is its output.
    Output(Fut::Output),
}
312
/// Initial state: the inner future has not produced its output yet.
const POLLING: usize = 0;
/// The output has been stored in the shared cell and can be read.
const COMPLETE: usize = 2;
/// A poll of the inner future panicked; the shared state must not be polled again.
const POISONED: usize = 3;
316
#[cfg(test)]
mod tests {
    use futures_util::FutureExt;

    use super::*;

    /// Returns `Pending` exactly once (after waking its own waker), then
    /// completes on the next poll.
    async fn yield_now() {
        struct YieldNow {
            yielded: bool,
        }

        impl Future for YieldNow {
            type Output = ();

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
                if self.yielded {
                    return Poll::Ready(());
                }

                self.yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }

        YieldNow { yielded: false }.await;
    }

    /// Stress test: repeats the drop-after-one-poll race many times; a lost
    /// wake-up would make the test hang.
    #[tokio::test(flavor = "multi_thread")]
    async fn must_not_hang_up() {
        for _ in 0..200 {
            for _ in 0..1000 {
                test_fut().await;
            }
        }
    }

    /// One clone is polled a single time and then dropped while the other
    /// clone is awaited to completion on a different task.
    async fn test_fut() {
        let f1 = Shared::new(yield_now());
        let f2 = f1.clone();
        let x1 = tokio::spawn(async move {
            f1.now_or_never();
        });
        let x2 = tokio::spawn(async move {
            f2.await;
        });
        // Surface panics from the spawned tasks instead of discarding the
        // `JoinError`, so that a panic inside either task fails the test.
        x1.await.expect("task polling `f1` once panicked");
        x2.await.expect("task awaiting `f2` panicked");
    }
}