Skip to main content

futures_util/future/future/
shared.rs

1use crate::task::{waker_ref, ArcWake};
2use alloc::sync::{Arc, Weak};
3use core::cell::UnsafeCell;
4use core::fmt;
5use core::hash::Hasher;
6use core::pin::Pin;
7use core::ptr;
8use core::sync::atomic::AtomicUsize;
9use core::sync::atomic::Ordering::{Acquire, SeqCst};
10use futures_core::future::{FusedFuture, Future};
11use futures_core::task::{Context, Poll, Waker};
12use slab::Slab;
13
14#[cfg(feature = "std")]
15type Mutex<T> = std::sync::Mutex<T>;
16#[cfg(not(feature = "std"))]
17type Mutex<T> = spin::Mutex<T>;
18
/// Future for the [`shared`](super::FutureExt::shared) method.
///
/// Every clone of a `Shared` points at the same shared state. Whichever clone
/// wins the transition into the polling state drives the wrapped future. Each
/// clone resolves to the same output: it is moved out for the handle holding
/// the last strong reference and cloned for every other handle.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Shared<Fut: Future> {
    // Shared state; `None` once this handle has returned its output from `poll`.
    inner: Option<Arc<Inner<Fut>>>,
    // Index of this handle's waker slot in the notifier's slab, or
    // `NULL_WAKER_KEY` if this handle has not registered a waker yet.
    waker_key: usize,
}
25
// State shared by all clones of a `Shared` future.
struct Inner<Fut: Future> {
    // The wrapped future until it completes, then its output. Access is
    // coordinated through `notifier.state`, not through a lock.
    future_or_output: UnsafeCell<FutureOrOutput<Fut>>,
    // Kept in its own `Arc` because `poll` builds the wrapped future's waker
    // from it with `waker_ref`.
    notifier: Arc<Notifier>,
}
30
// Polling state plus the wakers of every handle waiting on the shared future.
struct Notifier {
    // One of `IDLE`, `POLLING`, `COMPLETE` or `POISONED`.
    state: AtomicUsize,
    // Per-handle waker slots, indexed by `Shared::waker_key`. A slot holds
    // `None` after its waker has been taken and woken. The whole slab is
    // replaced by `None` once the wrapped future has completed.
    wakers: Mutex<Option<Slab<Option<Waker>>>>,
}
35
/// A weak reference to a [`Shared`] that can be upgraded much like an `Arc`.
///
/// Holding a `WeakShared` does not keep the shared state alive.
pub struct WeakShared<Fut: Future>(Weak<Inner<Fut>>);
38
39impl<Fut: Future> Clone for WeakShared<Fut> {
40    fn clone(&self) -> Self {
41        Self(self.0.clone())
42    }
43}
44
45impl<Fut: Future> fmt::Debug for Shared<Fut> {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        f.debug_struct("Shared")
48            .field("inner", &self.inner)
49            .field("waker_key", &self.waker_key)
50            .finish()
51    }
52}
53
54impl<Fut: Future> fmt::Debug for Inner<Fut> {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        f.debug_struct("Inner").finish()
57    }
58}
59
60impl<Fut: Future> fmt::Debug for WeakShared<Fut> {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        f.debug_struct("WeakShared").finish()
63    }
64}
65
// Contents of `Inner::future_or_output`: the wrapped future while it is still
// running, replaced by its output once a poll returns `Ready`.
enum FutureOrOutput<Fut: Future> {
    Future(Fut),
    Output(Fut::Output),
}
70
// SAFETY: `Inner` is shared between threads through `Arc`, and the
// `UnsafeCell` it contains is guarded by `notifier.state` rather than a lock.
// - The future is only touched by the single handle that moved the state from
//   `IDLE` to `POLLING`. It may therefore be polled from different threads,
//   but never concurrently, so `Fut: Send` is sufficient.
// - Once the state is `COMPLETE`, `peek` hands out `&Fut::Output` to any
//   number of handles at once, which requires `Fut::Output: Sync`.
// - `take_or_clone_output` moves or clones the output into the polling
//   thread, which requires `Fut::Output: Send`.
unsafe impl<Fut> Send for Inner<Fut>
where
    Fut: Future + Send,
    Fut::Output: Send + Sync,
{
}

// SAFETY: the same invariants as the `Send` impl make shared access from
// several threads sound.
unsafe impl<Fut> Sync for Inner<Fut>
where
    Fut: Future + Send,
    Fut::Output: Send + Sync,
{
}
84
// Values of `Notifier::state`:
// IDLE -> POLLING     a handle acquired the right to poll the wrapped future;
// POLLING -> IDLE     the wrapped future returned `Pending`;
// POLLING -> COMPLETE the output has been stored;
// POLLING -> POISONED the wrapped future panicked while being polled.
const IDLE: usize = 0;
const POLLING: usize = IDLE + 1;
const COMPLETE: usize = POLLING + 1;
const POISONED: usize = COMPLETE + 1;

// `Shared::waker_key` value meaning "no slot registered in the waker slab".
const NULL_WAKER_KEY: usize = !0;
91
92impl<Fut: Future> Shared<Fut> {
93    pub(super) fn new(future: Fut) -> Self {
94        let inner = Inner {
95            future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)),
96            notifier: Arc::new(Notifier {
97                state: AtomicUsize::new(IDLE),
98                wakers: Mutex::new(Some(Slab::new())),
99            }),
100        };
101
102        Self { inner: Some(Arc::new(inner)), waker_key: NULL_WAKER_KEY }
103    }
104}
105
106impl<Fut> Shared<Fut>
107where
108    Fut: Future,
109{
110    /// Returns [`Some`] containing a reference to this [`Shared`]'s output if
111    /// it has already been computed by a clone or [`None`] if it hasn't been
112    /// computed yet or this [`Shared`] already returned its output from
113    /// [`poll`](Future::poll).
114    pub fn peek(&self) -> Option<&Fut::Output> {
115        if let Some(inner) = self.inner.as_ref() {
116            match inner.notifier.state.load(SeqCst) {
117                COMPLETE => unsafe { return Some(inner.output()) },
118                POISONED => panic!("inner future panicked during poll"),
119                _ => {}
120            }
121        }
122        None
123    }
124
125    /// Creates a new [`WeakShared`] for this [`Shared`].
126    ///
127    /// Returns [`None`] if it has already been polled to completion.
128    pub fn downgrade(&self) -> Option<WeakShared<Fut>> {
129        if let Some(inner) = self.inner.as_ref() {
130            return Some(WeakShared(Arc::downgrade(inner)));
131        }
132        None
133    }
134
135    /// Gets the number of strong pointers to this allocation.
136    ///
137    /// Returns [`None`] if it has already been polled to completion.
138    ///
139    /// # Safety
140    ///
141    /// This method by itself is safe, but using it correctly requires extra care. Another thread
142    /// can change the strong count at any time, including potentially between calling this method
143    /// and acting on the result.
144    #[allow(clippy::unnecessary_safety_doc)]
145    pub fn strong_count(&self) -> Option<usize> {
146        self.inner.as_ref().map(|arc| Arc::strong_count(arc))
147    }
148
149    /// Gets the number of weak pointers to this allocation.
150    ///
151    /// Returns [`None`] if it has already been polled to completion.
152    ///
153    /// # Safety
154    ///
155    /// This method by itself is safe, but using it correctly requires extra care. Another thread
156    /// can change the weak count at any time, including potentially between calling this method
157    /// and acting on the result.
158    #[allow(clippy::unnecessary_safety_doc)]
159    pub fn weak_count(&self) -> Option<usize> {
160        self.inner.as_ref().map(|arc| Arc::weak_count(arc))
161    }
162
163    /// Hashes the internal state of this `Shared` in a way that's compatible with `ptr_eq`.
164    pub fn ptr_hash<H: Hasher>(&self, state: &mut H) {
165        match self.inner.as_ref() {
166            Some(arc) => {
167                state.write_u8(1);
168                ptr::hash(Arc::as_ptr(arc), state);
169            }
170            None => {
171                state.write_u8(0);
172            }
173        }
174    }
175
176    /// Returns `true` if the two `Shared`s point to the same future (in a vein similar to
177    /// `Arc::ptr_eq`).
178    ///
179    /// Returns `false` if either `Shared` has terminated.
180    pub fn ptr_eq(&self, rhs: &Self) -> bool {
181        let lhs = match self.inner.as_ref() {
182            Some(lhs) => lhs,
183            None => return false,
184        };
185        let rhs = match rhs.inner.as_ref() {
186            Some(rhs) => rhs,
187            None => return false,
188        };
189        Arc::ptr_eq(lhs, rhs)
190    }
191}
192
193impl<Fut> Inner<Fut>
194where
195    Fut: Future,
196{
197    /// Safety: callers must first ensure that `self.inner.state`
198    /// is `COMPLETE`
199    unsafe fn output(&self) -> &Fut::Output {
200        match unsafe { &*self.future_or_output.get() } {
201            FutureOrOutput::Output(item) => item,
202            FutureOrOutput::Future(_) => unreachable!(),
203        }
204    }
205}
206
207impl<Fut> Inner<Fut>
208where
209    Fut: Future,
210    Fut::Output: Clone,
211{
212    /// Registers the current task to receive a wakeup when we are awoken.
213    fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) {
214        #[cfg(feature = "std")]
215        let mut wakers_guard = self.notifier.wakers.lock().unwrap();
216        #[cfg(not(feature = "std"))]
217        let mut wakers_guard = self.notifier.wakers.lock();
218
219        let wakers = match wakers_guard.as_mut() {
220            Some(wakers) => wakers,
221            None => return,
222        };
223
224        let new_waker = cx.waker();
225
226        if *waker_key == NULL_WAKER_KEY {
227            *waker_key = wakers.insert(Some(new_waker.clone()));
228        } else {
229            match wakers[*waker_key] {
230                Some(ref old_waker) if new_waker.will_wake(old_waker) => {}
231                // Could use clone_from here, but Waker doesn't specialize it.
232                ref mut slot => *slot = Some(new_waker.clone()),
233            }
234        }
235        debug_assert!(*waker_key != NULL_WAKER_KEY);
236    }
237
238    /// Safety: callers must first ensure that `inner.state`
239    /// is `COMPLETE`
240    unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output {
241        match Arc::try_unwrap(self) {
242            Ok(inner) => match inner.future_or_output.into_inner() {
243                FutureOrOutput::Output(item) => item,
244                FutureOrOutput::Future(_) => unreachable!(),
245            },
246            Err(inner) => unsafe { inner.output().clone() },
247        }
248    }
249}
250
251impl<Fut> FusedFuture for Shared<Fut>
252where
253    Fut: Future,
254    Fut::Output: Clone,
255{
256    fn is_terminated(&self) -> bool {
257        self.inner.is_none()
258    }
259}
260
261impl<Fut> Future for Shared<Fut>
262where
263    Fut: Future,
264    Fut::Output: Clone,
265{
266    type Output = Fut::Output;
267
268    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
269        let this = &mut *self;
270
271        let inner = this.inner.take().expect("Shared future polled again after completion");
272
273        // Fast path for when the wrapped future has already completed
274        if inner.notifier.state.load(Acquire) == COMPLETE {
275            // Safety: We're in the COMPLETE state
276            return unsafe { Poll::Ready(inner.take_or_clone_output()) };
277        }
278
279        inner.record_waker(&mut this.waker_key, cx);
280
281        match inner
282            .notifier
283            .state
284            .compare_exchange(IDLE, POLLING, SeqCst, SeqCst)
285            .unwrap_or_else(|x| x)
286        {
287            IDLE => {
288                // Lock acquired, fall through
289            }
290            POLLING => {
291                // Another task is currently polling, at this point we just want
292                // to ensure that the waker for this task is registered
293                this.inner = Some(inner);
294                return Poll::Pending;
295            }
296            COMPLETE => {
297                // Safety: We're in the COMPLETE state
298                return unsafe { Poll::Ready(inner.take_or_clone_output()) };
299            }
300            POISONED => panic!("inner future panicked during poll"),
301            _ => unreachable!(),
302        }
303
304        let waker = waker_ref(&inner.notifier);
305        let mut cx = Context::from_waker(&waker);
306
307        struct Reset<'a> {
308            state: &'a AtomicUsize,
309            did_not_panic: bool,
310        }
311
312        impl Drop for Reset<'_> {
313            fn drop(&mut self) {
314                if !self.did_not_panic {
315                    self.state.store(POISONED, SeqCst);
316                }
317            }
318        }
319
320        let mut reset = Reset { state: &inner.notifier.state, did_not_panic: false };
321
322        let output = {
323            let future = unsafe {
324                match &mut *inner.future_or_output.get() {
325                    FutureOrOutput::Future(fut) => Pin::new_unchecked(fut),
326                    _ => unreachable!(),
327                }
328            };
329
330            let poll_result = future.poll(&mut cx);
331            reset.did_not_panic = true;
332
333            match poll_result {
334                Poll::Pending => {
335                    if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok()
336                    {
337                        // Success
338                        drop(reset);
339                        this.inner = Some(inner);
340                        return Poll::Pending;
341                    } else {
342                        unreachable!()
343                    }
344                }
345                Poll::Ready(output) => output,
346            }
347        };
348
349        unsafe {
350            *inner.future_or_output.get() = FutureOrOutput::Output(output);
351        }
352
353        inner.notifier.state.store(COMPLETE, SeqCst);
354
355        // Wake all tasks and drop the slab
356        #[cfg(feature = "std")]
357        let mut wakers_guard = inner.notifier.wakers.lock().unwrap();
358        #[cfg(not(feature = "std"))]
359        let mut wakers_guard = inner.notifier.wakers.lock();
360
361        let mut wakers = wakers_guard.take().unwrap();
362        for waker in wakers.drain().flatten() {
363            waker.wake();
364        }
365
366        drop(reset); // Make borrow checker happy
367        drop(wakers_guard);
368
369        // Safety: We're in the COMPLETE state
370        unsafe { Poll::Ready(inner.take_or_clone_output()) }
371    }
372}
373
374impl<Fut> Clone for Shared<Fut>
375where
376    Fut: Future,
377{
378    fn clone(&self) -> Self {
379        Self { inner: self.inner.clone(), waker_key: NULL_WAKER_KEY }
380    }
381}
382
383impl<Fut> Drop for Shared<Fut>
384where
385    Fut: Future,
386{
387    fn drop(&mut self) {
388        if self.waker_key != NULL_WAKER_KEY {
389            if let Some(ref inner) = self.inner {
390                #[cfg(feature = "std")]
391                if let Ok(mut wakers) = inner.notifier.wakers.lock() {
392                    if let Some(wakers) = wakers.as_mut() {
393                        wakers.remove(self.waker_key);
394                    }
395                }
396                #[cfg(not(feature = "std"))]
397                if let Some(wakers) = inner.notifier.wakers.lock().as_mut() {
398                    wakers.remove(self.waker_key);
399                }
400            }
401        }
402    }
403}
404
405impl ArcWake for Notifier {
406    fn wake_by_ref(arc_self: &Arc<Self>) {
407        #[cfg(feature = "std")]
408        let wakers = &mut *arc_self.wakers.lock().unwrap();
409        #[cfg(not(feature = "std"))]
410        let wakers = &mut *arc_self.wakers.lock();
411
412        if let Some(wakers) = wakers.as_mut() {
413            for (_key, opt_waker) in wakers {
414                if let Some(waker) = opt_waker.take() {
415                    waker.wake();
416                }
417            }
418        }
419    }
420}
421
422impl<Fut: Future> WeakShared<Fut> {
423    /// Attempts to upgrade this [`WeakShared`] into a [`Shared`].
424    ///
425    /// Returns [`None`] if all clones of the [`Shared`] have been dropped or polled
426    /// to completion.
427    pub fn upgrade(&self) -> Option<Shared<Fut>> {
428        Some(Shared { inner: Some(self.0.upgrade()?), waker_key: NULL_WAKER_KEY })
429    }
430}