// futures_util/future/shared.rs

//! Definition of the Shared combinator, a future that is cloneable
//! and can be polled in multiple threads.
//!
//! # Examples
//!
//! ```
//! # extern crate futures;
//! # extern crate futures_executor;
//! use futures::prelude::*;
//! use futures::future;
//! use futures_executor::block_on;
//!
//! # fn main() {
//! let future = future::ok::<_, bool>(6);
//! let shared1 = future.shared();
//! let shared2 = shared1.clone();
//! assert_eq!(6, *block_on(shared1).unwrap());
//! assert_eq!(6, *block_on(shared2).unwrap());
//! # }
//! ```

22use std::{error, fmt, mem, ops};
23use std::cell::UnsafeCell;
24use std::sync::{Arc, Mutex};
25use std::sync::atomic::AtomicUsize;
26use std::sync::atomic::Ordering::SeqCst;
27use std::collections::HashMap;
28
29use futures_core::{Future, Poll, Async};
30use futures_core::task::{self, Wake, Waker, LocalMap};
31
32/// A future that is cloneable and can be polled in multiple threads.
33/// Use `Future::shared()` method to convert any future into a `Shared` future.
34#[must_use = "futures do nothing unless polled"]
35pub struct Shared<F: Future> {
36    inner: Arc<Inner<F>>,
37    waiter: usize,
38}
39
40impl<F> fmt::Debug for Shared<F>
41    where F: Future + fmt::Debug,
42          F::Item: fmt::Debug,
43          F::Error: fmt::Debug,
44{
45    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
46        fmt.debug_struct("Shared")
47            .field("inner", &self.inner)
48            .field("waiter", &self.waiter)
49            .finish()
50    }
51}
52
53struct Inner<F: Future> {
54    next_clone_id: AtomicUsize,
55    future: UnsafeCell<Option<(F, LocalMap)>>,
56    result: UnsafeCell<Option<Result<SharedItem<F::Item>, SharedError<F::Error>>>>,
57    notifier: Arc<Notifier>,
58}
59
60struct Notifier {
61    state: AtomicUsize,
62    waiters: Mutex<HashMap<usize, Waker>>,
63}
64
/// No handle is polling the inner future; the next poller may take over.
const IDLE: usize = 0;
/// A handle is currently polling the inner future.
const POLLING: usize = 1;
/// A wake-up arrived while a handle was polling; that handle must poll
/// again before going idle.
const REPOLL: usize = 2;
/// The inner future resolved and its result is stored.
const COMPLETE: usize = 3;
/// Polling the inner future panicked; every later poll or peek panics too.
const POISONED: usize = 4;

71pub fn new<F: Future>(future: F) -> Shared<F> {
72    Shared {
73        inner: Arc::new(Inner {
74            next_clone_id: AtomicUsize::new(1),
75            notifier: Arc::new(Notifier {
76                state: AtomicUsize::new(IDLE),
77                waiters: Mutex::new(HashMap::new()),
78            }),
79            future: UnsafeCell::new(Some((future, LocalMap::new()))),
80            result: UnsafeCell::new(None),
81        }),
82        waiter: 0,
83    }
84}
85
86impl<F> Shared<F> where F: Future {
87    /// If any clone of this `Shared` has completed execution, returns its result immediately
88    /// without blocking. Otherwise, returns None without triggering the work represented by
89    /// this `Shared`.
90    pub fn peek(&self) -> Option<Result<SharedItem<F::Item>, SharedError<F::Error>>> {
91        match self.inner.notifier.state.load(SeqCst) {
92            COMPLETE => {
93                Some(unsafe { self.clone_result() })
94            }
95            POISONED => panic!("inner future panicked during poll"),
96            _ => None,
97        }
98    }
99
100    fn set_waiter(&mut self, cx: &mut task::Context) {
101        let mut waiters = self.inner.notifier.waiters.lock().unwrap();
102        waiters.insert(self.waiter, cx.waker().clone());
103    }
104
105    unsafe fn clone_result(&self) -> Result<SharedItem<F::Item>, SharedError<F::Error>> {
106        match *self.inner.result.get() {
107            Some(Ok(ref item)) => Ok(SharedItem { item: item.item.clone() }),
108            Some(Err(ref e)) => Err(SharedError { error: e.error.clone() }),
109            _ => unreachable!(),
110        }
111    }
112
113    fn complete(&self) {
114        unsafe { *self.inner.future.get() = None };
115        self.inner.notifier.state.store(COMPLETE, SeqCst);
116        Wake::wake(&self.inner.notifier);
117    }
118}
119
120impl<F> Future for Shared<F>
121    where F: Future
122{
123    type Item = SharedItem<F::Item>;
124    type Error = SharedError<F::Error>;
125
126    fn poll(&mut self, cx: &mut task::Context) -> Poll<Self::Item, Self::Error> {
127        self.set_waiter(cx);
128
129        match self.inner.notifier.state.compare_and_swap(IDLE, POLLING, SeqCst) {
130            IDLE => {
131                // Lock acquired, fall through
132            }
133            POLLING | REPOLL => {
134                // Another task is currently polling, at this point we just want
135                // to ensure that our task handle is currently registered
136
137                return Ok(Async::Pending);
138            }
139            COMPLETE => {
140                return unsafe { self.clone_result().map(Async::Ready) };
141            }
142            POISONED => panic!("inner future panicked during poll"),
143            _ => unreachable!(),
144        }
145
146        loop {
147            struct Reset<'a>(&'a AtomicUsize);
148
149            impl<'a> Drop for Reset<'a> {
150                fn drop(&mut self) {
151                    use std::thread;
152
153                    if thread::panicking() {
154                        self.0.store(POISONED, SeqCst);
155                    }
156                }
157            }
158
159            let _reset = Reset(&self.inner.notifier.state);
160
161            // Poll the future
162            let res = unsafe {
163                let (ref mut f, ref mut data) = *(*self.inner.future.get()).as_mut().unwrap();
164                let waker = Waker::from(self.inner.notifier.clone());
165                let mut cx = task::Context::new(data, &waker, cx.executor());
166                f.poll(&mut cx)
167            };
168            match res {
169                Ok(Async::Pending) => {
170                    // Not ready, try to release the handle
171                    match self.inner.notifier.state.compare_and_swap(POLLING, IDLE, SeqCst) {
172                        POLLING => {
173                            // Success
174                            return Ok(Async::Pending);
175                        }
176                        REPOLL => {
177                            // Gotta poll again!
178                            let prev = self.inner.notifier.state.swap(POLLING, SeqCst);
179                            assert_eq!(prev, REPOLL);
180                        }
181                        _ => unreachable!(),
182                    }
183
184                }
185                Ok(Async::Ready(i)) => {
186                    unsafe {
187                        (*self.inner.result.get()) = Some(Ok(SharedItem { item: Arc::new(i) }));
188                    }
189
190                    break;
191                }
192                Err(e) => {
193                    unsafe {
194                        (*self.inner.result.get()) = Some(Err(SharedError { error: Arc::new(e) }));
195                    }
196
197                    break;
198                }
199            }
200        }
201
202        self.complete();
203        unsafe { self.clone_result().map(Async::Ready) }
204    }
205}
206
207impl<F> Clone for Shared<F> where F: Future {
208    fn clone(&self) -> Self {
209        let next_clone_id = self.inner.next_clone_id.fetch_add(1, SeqCst);
210
211        Shared {
212            inner: self.inner.clone(),
213            waiter: next_clone_id,
214        }
215    }
216}
217
218impl<F> Drop for Shared<F> where F: Future {
219    fn drop(&mut self) {
220        let mut waiters = self.inner.notifier.waiters.lock().unwrap();
221        waiters.remove(&self.waiter);
222    }
223}
224
225impl Wake for Notifier {
226    fn wake(arc_self: &Arc<Self>) {
227        arc_self.state.compare_and_swap(POLLING, REPOLL, SeqCst);
228
229        let waiters = mem::replace(&mut *arc_self.waiters.lock().unwrap(), HashMap::new());
230
231        for (_, waiter) in waiters {
232            waiter.wake();
233        }
234    }
235}
236
237unsafe impl<F> Sync for Inner<F>
238    where F: Future + Send,
239          F::Item: Send + Sync,
240          F::Error: Send + Sync
241{}
242
243unsafe impl<F> Send for Inner<F>
244    where F: Future + Send,
245          F::Item: Send + Sync,
246          F::Error: Send + Sync
247{}
248
249impl<F> fmt::Debug for Inner<F>
250    where F: Future + fmt::Debug,
251          F::Item: fmt::Debug,
252          F::Error: fmt::Debug,
253{
254    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
255        fmt.debug_struct("Inner")
256            .finish()
257    }
258}
259
/// A wrapped item of the original future that is cloneable and implements Deref
/// for ease of use.
///
/// Cloning only bumps the `Arc` reference count, so it works for any `T`,
/// including item types that are not `Clone` themselves.
#[derive(Debug)]
pub struct SharedItem<T> {
    item: Arc<T>,
}

// Manual impl instead of `#[derive(Clone)]`: the derive would require
// `T: Clone` even though cloning the `Arc` never clones `T`.
impl<T> Clone for SharedItem<T> {
    fn clone(&self) -> Self {
        SharedItem { item: Arc::clone(&self.item) }
    }
}

impl<T> SharedItem<T> {
    /// Consumes the wrapper and returns the underlying `Arc<T>`.
    pub fn into_inner(self) -> Arc<T> {
        self.item
    }
}

impl<T> ops::Deref for SharedItem<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.item
    }
}

/// A wrapped error of the original future that is cloneable and implements Deref
/// for ease of use.
///
/// Cloning only bumps the `Arc` reference count, so it works for any `E`,
/// including error types that are not `Clone` themselves.
#[derive(Debug)]
pub struct SharedError<E> {
    error: Arc<E>,
}

// Manual impl instead of `#[derive(Clone)]`: the derive would require
// `E: Clone` even though cloning the `Arc` never clones `E`.
impl<E> Clone for SharedError<E> {
    fn clone(&self) -> Self {
        SharedError { error: Arc::clone(&self.error) }
    }
}

impl<E> SharedError<E> {
    /// Consumes the wrapper and returns the underlying `Arc<E>`.
    pub fn into_inner(self) -> Arc<E> {
        self.error
    }
}

impl<E> ops::Deref for SharedError<E> {
    type Target = E;

    fn deref(&self) -> &E {
        &self.error
    }
}

impl<E> fmt::Display for SharedError<E>
    where E: fmt::Display,
{
    /// Displays exactly like the wrapped error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(&*self.error, f)
    }
}

impl<E> error::Error for SharedError<E>
    where E: error::Error,
{
    // Kept for callers of the deprecated API; forwards to the wrapped error.
    #[allow(deprecated)]
    fn description(&self) -> &str {
        self.error.description()
    }

    // Kept for callers of the deprecated API; forwards to the wrapped error.
    #[allow(deprecated)]
    fn cause(&self) -> Option<&dyn error::Error> {
        self.error.cause()
    }

    /// Forwards to the wrapped error. Without this override `source()` would
    /// fall back to the default and always return `None`, breaking error-chain
    /// traversal.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        self.error.source()
    }
}