// Source path: futures_util/future/future/
shared.rs1use crate::task::{waker_ref, ArcWake};
2use alloc::sync::{Arc, Weak};
3use core::cell::UnsafeCell;
4use core::fmt;
5use core::hash::Hasher;
6use core::pin::Pin;
7use core::ptr;
8use core::sync::atomic::AtomicUsize;
9use core::sync::atomic::Ordering::{Acquire, SeqCst};
10use futures_core::future::{FusedFuture, Future};
11use futures_core::task::{Context, Poll, Waker};
12use slab::Slab;
13
14#[cfg(feature = "std")]
15type Mutex<T> = std::sync::Mutex<T>;
16#[cfg(not(feature = "std"))]
17type Mutex<T> = spin::Mutex<T>;
18
19#[must_use = "futures do nothing unless you `.await` or poll them"]
21pub struct Shared<Fut: Future> {
22 inner: Option<Arc<Inner<Fut>>>,
23 waker_key: usize,
24}
25
26struct Inner<Fut: Future> {
27 future_or_output: UnsafeCell<FutureOrOutput<Fut>>,
28 notifier: Arc<Notifier>,
29}
30
31struct Notifier {
32 state: AtomicUsize,
33 wakers: Mutex<Option<Slab<Option<Waker>>>>,
34}
35
36pub struct WeakShared<Fut: Future>(Weak<Inner<Fut>>);
38
39impl<Fut: Future> Clone for WeakShared<Fut> {
40 fn clone(&self) -> Self {
41 Self(self.0.clone())
42 }
43}
44
45impl<Fut: Future> fmt::Debug for Shared<Fut> {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 f.debug_struct("Shared")
48 .field("inner", &self.inner)
49 .field("waker_key", &self.waker_key)
50 .finish()
51 }
52}
53
54impl<Fut: Future> fmt::Debug for Inner<Fut> {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 f.debug_struct("Inner").finish()
57 }
58}
59
60impl<Fut: Future> fmt::Debug for WeakShared<Fut> {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 f.debug_struct("WeakShared").finish()
63 }
64}
65
/// Contents of the shared cell: the pending future, replaced in place by its
/// output once it completes.
enum FutureOrOutput<Fut: Future> {
    /// Not completed yet; polled by whichever handle holds the `POLLING` state.
    Future(Fut),
    /// Completed; read (and cloned) by every handle.
    Output(Fut::Output),
}
70
71unsafe impl<Fut> Send for Inner<Fut>
72where
73 Fut: Future + Send,
74 Fut::Output: Send + Sync,
75{
76}
77
78unsafe impl<Fut> Sync for Inner<Fut>
79where
80 Fut: Future + Send,
81 Fut::Output: Send + Sync,
82{
83}
84
// Values stored in `Notifier::state`.
/// Nobody is polling the inner future.
const IDLE: usize = 0;
/// One handle is currently polling the inner future.
const POLLING: usize = IDLE + 1;
/// The output has been written into the shared cell.
const COMPLETE: usize = POLLING + 1;
/// Polling the inner future panicked.
const POISONED: usize = COMPLETE + 1;

/// `waker_key` of a handle that has not registered a waker.
const NULL_WAKER_KEY: usize = usize::MAX;
91
92impl<Fut: Future> Shared<Fut> {
93 pub(super) fn new(future: Fut) -> Self {
94 let inner = Inner {
95 future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)),
96 notifier: Arc::new(Notifier {
97 state: AtomicUsize::new(IDLE),
98 wakers: Mutex::new(Some(Slab::new())),
99 }),
100 };
101
102 Self { inner: Some(Arc::new(inner)), waker_key: NULL_WAKER_KEY }
103 }
104}
105
106impl<Fut> Shared<Fut>
107where
108 Fut: Future,
109{
110 pub fn peek(&self) -> Option<&Fut::Output> {
115 if let Some(inner) = self.inner.as_ref() {
116 match inner.notifier.state.load(SeqCst) {
117 COMPLETE => unsafe { return Some(inner.output()) },
118 POISONED => panic!("inner future panicked during poll"),
119 _ => {}
120 }
121 }
122 None
123 }
124
125 pub fn downgrade(&self) -> Option<WeakShared<Fut>> {
129 if let Some(inner) = self.inner.as_ref() {
130 return Some(WeakShared(Arc::downgrade(inner)));
131 }
132 None
133 }
134
135 #[allow(clippy::unnecessary_safety_doc)]
145 pub fn strong_count(&self) -> Option<usize> {
146 self.inner.as_ref().map(|arc| Arc::strong_count(arc))
147 }
148
149 #[allow(clippy::unnecessary_safety_doc)]
159 pub fn weak_count(&self) -> Option<usize> {
160 self.inner.as_ref().map(|arc| Arc::weak_count(arc))
161 }
162
163 pub fn ptr_hash<H: Hasher>(&self, state: &mut H) {
165 match self.inner.as_ref() {
166 Some(arc) => {
167 state.write_u8(1);
168 ptr::hash(Arc::as_ptr(arc), state);
169 }
170 None => {
171 state.write_u8(0);
172 }
173 }
174 }
175
176 pub fn ptr_eq(&self, rhs: &Self) -> bool {
181 let lhs = match self.inner.as_ref() {
182 Some(lhs) => lhs,
183 None => return false,
184 };
185 let rhs = match rhs.inner.as_ref() {
186 Some(rhs) => rhs,
187 None => return false,
188 };
189 Arc::ptr_eq(lhs, rhs)
190 }
191}
192
193impl<Fut> Inner<Fut>
194where
195 Fut: Future,
196{
197 unsafe fn output(&self) -> &Fut::Output {
200 match unsafe { &*self.future_or_output.get() } {
201 FutureOrOutput::Output(item) => item,
202 FutureOrOutput::Future(_) => unreachable!(),
203 }
204 }
205}
206
207impl<Fut> Inner<Fut>
208where
209 Fut: Future,
210 Fut::Output: Clone,
211{
212 fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) {
214 #[cfg(feature = "std")]
215 let mut wakers_guard = self.notifier.wakers.lock().unwrap();
216 #[cfg(not(feature = "std"))]
217 let mut wakers_guard = self.notifier.wakers.lock();
218
219 let wakers = match wakers_guard.as_mut() {
220 Some(wakers) => wakers,
221 None => return,
222 };
223
224 let new_waker = cx.waker();
225
226 if *waker_key == NULL_WAKER_KEY {
227 *waker_key = wakers.insert(Some(new_waker.clone()));
228 } else {
229 match wakers[*waker_key] {
230 Some(ref old_waker) if new_waker.will_wake(old_waker) => {}
231 ref mut slot => *slot = Some(new_waker.clone()),
233 }
234 }
235 debug_assert!(*waker_key != NULL_WAKER_KEY);
236 }
237
238 unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output {
241 match Arc::try_unwrap(self) {
242 Ok(inner) => match inner.future_or_output.into_inner() {
243 FutureOrOutput::Output(item) => item,
244 FutureOrOutput::Future(_) => unreachable!(),
245 },
246 Err(inner) => unsafe { inner.output().clone() },
247 }
248 }
249}
250
251impl<Fut> FusedFuture for Shared<Fut>
252where
253 Fut: Future,
254 Fut::Output: Clone,
255{
256 fn is_terminated(&self) -> bool {
257 self.inner.is_none()
258 }
259}
260
261impl<Fut> Future for Shared<Fut>
262where
263 Fut: Future,
264 Fut::Output: Clone,
265{
266 type Output = Fut::Output;
267
268 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
269 let this = &mut *self;
270
271 let inner = this.inner.take().expect("Shared future polled again after completion");
272
273 if inner.notifier.state.load(Acquire) == COMPLETE {
275 return unsafe { Poll::Ready(inner.take_or_clone_output()) };
277 }
278
279 inner.record_waker(&mut this.waker_key, cx);
280
281 match inner
282 .notifier
283 .state
284 .compare_exchange(IDLE, POLLING, SeqCst, SeqCst)
285 .unwrap_or_else(|x| x)
286 {
287 IDLE => {
288 }
290 POLLING => {
291 this.inner = Some(inner);
294 return Poll::Pending;
295 }
296 COMPLETE => {
297 return unsafe { Poll::Ready(inner.take_or_clone_output()) };
299 }
300 POISONED => panic!("inner future panicked during poll"),
301 _ => unreachable!(),
302 }
303
304 let waker = waker_ref(&inner.notifier);
305 let mut cx = Context::from_waker(&waker);
306
307 struct Reset<'a> {
308 state: &'a AtomicUsize,
309 did_not_panic: bool,
310 }
311
312 impl Drop for Reset<'_> {
313 fn drop(&mut self) {
314 if !self.did_not_panic {
315 self.state.store(POISONED, SeqCst);
316 }
317 }
318 }
319
320 let mut reset = Reset { state: &inner.notifier.state, did_not_panic: false };
321
322 let output = {
323 let future = unsafe {
324 match &mut *inner.future_or_output.get() {
325 FutureOrOutput::Future(fut) => Pin::new_unchecked(fut),
326 _ => unreachable!(),
327 }
328 };
329
330 let poll_result = future.poll(&mut cx);
331 reset.did_not_panic = true;
332
333 match poll_result {
334 Poll::Pending => {
335 if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok()
336 {
337 drop(reset);
339 this.inner = Some(inner);
340 return Poll::Pending;
341 } else {
342 unreachable!()
343 }
344 }
345 Poll::Ready(output) => output,
346 }
347 };
348
349 unsafe {
350 *inner.future_or_output.get() = FutureOrOutput::Output(output);
351 }
352
353 inner.notifier.state.store(COMPLETE, SeqCst);
354
355 #[cfg(feature = "std")]
357 let mut wakers_guard = inner.notifier.wakers.lock().unwrap();
358 #[cfg(not(feature = "std"))]
359 let mut wakers_guard = inner.notifier.wakers.lock();
360
361 let mut wakers = wakers_guard.take().unwrap();
362 for waker in wakers.drain().flatten() {
363 waker.wake();
364 }
365
366 drop(reset); drop(wakers_guard);
368
369 unsafe { Poll::Ready(inner.take_or_clone_output()) }
371 }
372}
373
374impl<Fut> Clone for Shared<Fut>
375where
376 Fut: Future,
377{
378 fn clone(&self) -> Self {
379 Self { inner: self.inner.clone(), waker_key: NULL_WAKER_KEY }
380 }
381}
382
383impl<Fut> Drop for Shared<Fut>
384where
385 Fut: Future,
386{
387 fn drop(&mut self) {
388 if self.waker_key != NULL_WAKER_KEY {
389 if let Some(ref inner) = self.inner {
390 #[cfg(feature = "std")]
391 if let Ok(mut wakers) = inner.notifier.wakers.lock() {
392 if let Some(wakers) = wakers.as_mut() {
393 wakers.remove(self.waker_key);
394 }
395 }
396 #[cfg(not(feature = "std"))]
397 if let Some(wakers) = inner.notifier.wakers.lock().as_mut() {
398 wakers.remove(self.waker_key);
399 }
400 }
401 }
402 }
403}
404
405impl ArcWake for Notifier {
406 fn wake_by_ref(arc_self: &Arc<Self>) {
407 #[cfg(feature = "std")]
408 let wakers = &mut *arc_self.wakers.lock().unwrap();
409 #[cfg(not(feature = "std"))]
410 let wakers = &mut *arc_self.wakers.lock();
411
412 if let Some(wakers) = wakers.as_mut() {
413 for (_key, opt_waker) in wakers {
414 if let Some(waker) = opt_waker.take() {
415 waker.wake();
416 }
417 }
418 }
419 }
420}
421
422impl<Fut: Future> WeakShared<Fut> {
423 pub fn upgrade(&self) -> Option<Shared<Fut>> {
428 Some(Shared { inner: Some(self.0.upgrade()?), waker_key: NULL_WAKER_KEY })
429 }
430}