futures_util/future/
shared.rs1use std::{error, fmt, mem, ops};
23use std::cell::UnsafeCell;
24use std::sync::{Arc, Mutex};
25use std::sync::atomic::AtomicUsize;
26use std::sync::atomic::Ordering::SeqCst;
27use std::collections::HashMap;
28
29use futures_core::{Future, Poll, Async};
30use futures_core::task::{self, Wake, Waker, LocalMap};
31
32#[must_use = "futures do nothing unless polled"]
35pub struct Shared<F: Future> {
36 inner: Arc<Inner<F>>,
37 waiter: usize,
38}
39
40impl<F> fmt::Debug for Shared<F>
41 where F: Future + fmt::Debug,
42 F::Item: fmt::Debug,
43 F::Error: fmt::Debug,
44{
45 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
46 fmt.debug_struct("Shared")
47 .field("inner", &self.inner)
48 .field("waiter", &self.waiter)
49 .finish()
50 }
51}
52
53struct Inner<F: Future> {
54 next_clone_id: AtomicUsize,
55 future: UnsafeCell<Option<(F, LocalMap)>>,
56 result: UnsafeCell<Option<Result<SharedItem<F::Item>, SharedError<F::Error>>>>,
57 notifier: Arc<Notifier>,
58}
59
60struct Notifier {
61 state: AtomicUsize,
62 waiters: Mutex<HashMap<usize, Waker>>,
63}
64
// Values of `Notifier::state`; each is one more than the previous.

/// Nobody is polling the inner future.
const IDLE: usize = 0;
/// Exactly one handle is polling the inner future.
const POLLING: usize = IDLE + 1;
/// The notifier was woken while a poll was in progress; the polling handle
/// must poll the inner future again.
const REPOLL: usize = POLLING + 1;
/// The inner future has resolved and its result is stored.
const COMPLETE: usize = REPOLL + 1;
/// The inner future panicked while being polled.
const POISONED: usize = COMPLETE + 1;
70
71pub fn new<F: Future>(future: F) -> Shared<F> {
72 Shared {
73 inner: Arc::new(Inner {
74 next_clone_id: AtomicUsize::new(1),
75 notifier: Arc::new(Notifier {
76 state: AtomicUsize::new(IDLE),
77 waiters: Mutex::new(HashMap::new()),
78 }),
79 future: UnsafeCell::new(Some((future, LocalMap::new()))),
80 result: UnsafeCell::new(None),
81 }),
82 waiter: 0,
83 }
84}
85
86impl<F> Shared<F> where F: Future {
87 pub fn peek(&self) -> Option<Result<SharedItem<F::Item>, SharedError<F::Error>>> {
91 match self.inner.notifier.state.load(SeqCst) {
92 COMPLETE => {
93 Some(unsafe { self.clone_result() })
94 }
95 POISONED => panic!("inner future panicked during poll"),
96 _ => None,
97 }
98 }
99
100 fn set_waiter(&mut self, cx: &mut task::Context) {
101 let mut waiters = self.inner.notifier.waiters.lock().unwrap();
102 waiters.insert(self.waiter, cx.waker().clone());
103 }
104
105 unsafe fn clone_result(&self) -> Result<SharedItem<F::Item>, SharedError<F::Error>> {
106 match *self.inner.result.get() {
107 Some(Ok(ref item)) => Ok(SharedItem { item: item.item.clone() }),
108 Some(Err(ref e)) => Err(SharedError { error: e.error.clone() }),
109 _ => unreachable!(),
110 }
111 }
112
113 fn complete(&self) {
114 unsafe { *self.inner.future.get() = None };
115 self.inner.notifier.state.store(COMPLETE, SeqCst);
116 Wake::wake(&self.inner.notifier);
117 }
118}
119
120impl<F> Future for Shared<F>
121 where F: Future
122{
123 type Item = SharedItem<F::Item>;
124 type Error = SharedError<F::Error>;
125
126 fn poll(&mut self, cx: &mut task::Context) -> Poll<Self::Item, Self::Error> {
127 self.set_waiter(cx);
128
129 match self.inner.notifier.state.compare_and_swap(IDLE, POLLING, SeqCst) {
130 IDLE => {
131 }
133 POLLING | REPOLL => {
134 return Ok(Async::Pending);
138 }
139 COMPLETE => {
140 return unsafe { self.clone_result().map(Async::Ready) };
141 }
142 POISONED => panic!("inner future panicked during poll"),
143 _ => unreachable!(),
144 }
145
146 loop {
147 struct Reset<'a>(&'a AtomicUsize);
148
149 impl<'a> Drop for Reset<'a> {
150 fn drop(&mut self) {
151 use std::thread;
152
153 if thread::panicking() {
154 self.0.store(POISONED, SeqCst);
155 }
156 }
157 }
158
159 let _reset = Reset(&self.inner.notifier.state);
160
161 let res = unsafe {
163 let (ref mut f, ref mut data) = *(*self.inner.future.get()).as_mut().unwrap();
164 let waker = Waker::from(self.inner.notifier.clone());
165 let mut cx = task::Context::new(data, &waker, cx.executor());
166 f.poll(&mut cx)
167 };
168 match res {
169 Ok(Async::Pending) => {
170 match self.inner.notifier.state.compare_and_swap(POLLING, IDLE, SeqCst) {
172 POLLING => {
173 return Ok(Async::Pending);
175 }
176 REPOLL => {
177 let prev = self.inner.notifier.state.swap(POLLING, SeqCst);
179 assert_eq!(prev, REPOLL);
180 }
181 _ => unreachable!(),
182 }
183
184 }
185 Ok(Async::Ready(i)) => {
186 unsafe {
187 (*self.inner.result.get()) = Some(Ok(SharedItem { item: Arc::new(i) }));
188 }
189
190 break;
191 }
192 Err(e) => {
193 unsafe {
194 (*self.inner.result.get()) = Some(Err(SharedError { error: Arc::new(e) }));
195 }
196
197 break;
198 }
199 }
200 }
201
202 self.complete();
203 unsafe { self.clone_result().map(Async::Ready) }
204 }
205}
206
207impl<F> Clone for Shared<F> where F: Future {
208 fn clone(&self) -> Self {
209 let next_clone_id = self.inner.next_clone_id.fetch_add(1, SeqCst);
210
211 Shared {
212 inner: self.inner.clone(),
213 waiter: next_clone_id,
214 }
215 }
216}
217
218impl<F> Drop for Shared<F> where F: Future {
219 fn drop(&mut self) {
220 let mut waiters = self.inner.notifier.waiters.lock().unwrap();
221 waiters.remove(&self.waiter);
222 }
223}
224
225impl Wake for Notifier {
226 fn wake(arc_self: &Arc<Self>) {
227 arc_self.state.compare_and_swap(POLLING, REPOLL, SeqCst);
228
229 let waiters = mem::replace(&mut *arc_self.waiters.lock().unwrap(), HashMap::new());
230
231 for (_, waiter) in waiters {
232 waiter.wake();
233 }
234 }
235}
236
237unsafe impl<F> Sync for Inner<F>
238 where F: Future + Send,
239 F::Item: Send + Sync,
240 F::Error: Send + Sync
241{}
242
243unsafe impl<F> Send for Inner<F>
244 where F: Future + Send,
245 F::Item: Send + Sync,
246 F::Error: Send + Sync
247{}
248
249impl<F> fmt::Debug for Inner<F>
250 where F: Future + fmt::Debug,
251 F::Item: fmt::Debug,
252 F::Error: fmt::Debug,
253{
254 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
255 fmt.debug_struct("Inner")
256 .finish()
257 }
258}
259
/// A cheaply clonable handle to the value produced by a `Shared` future.
///
/// It dereferences to the value, and all clones point at the same `Arc`
/// allocation.
#[derive(Debug)]
pub struct SharedItem<T> {
    item: Arc<T>,
}

// Written by hand instead of derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, although cloning only bumps the `Arc` reference count.
impl<T> Clone for SharedItem<T> {
    fn clone(&self) -> Self {
        SharedItem { item: Arc::clone(&self.item) }
    }
}

impl<T> SharedItem<T> {
    /// Consumes the handle and returns the `Arc` that holds the value.
    pub fn into_inner(self) -> Arc<T> {
        self.item
    }
}

impl<T> ops::Deref for SharedItem<T> {
    type Target = T;

    fn deref(&self) -> &T {
        // `&Arc<T>` deref-coerces to `&T`.
        &self.item
    }
}
281
/// A cheaply clonable handle to the error produced by a `Shared` future.
///
/// It dereferences to the error, and all clones point at the same `Arc`
/// allocation.
#[derive(Debug)]
pub struct SharedError<E> {
    error: Arc<E>,
}

// Written by hand instead of derived: `#[derive(Clone)]` would add an
// `E: Clone` bound, although cloning only bumps the `Arc` reference count.
impl<E> Clone for SharedError<E> {
    fn clone(&self) -> Self {
        SharedError { error: Arc::clone(&self.error) }
    }
}

impl<E> SharedError<E> {
    /// Consumes the handle and returns the `Arc` that holds the error.
    pub fn into_inner(self) -> Arc<E> {
        self.error
    }
}

impl<E> ops::Deref for SharedError<E> {
    type Target = E;

    fn deref(&self) -> &E {
        // `&Arc<E>` deref-coerces to `&E`.
        &self.error
    }
}

impl<E> fmt::Display for SharedError<E>
    where E: fmt::Display,
{
    /// Displays exactly like the wrapped error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(&*self.error, f)
    }
}

impl<E> error::Error for SharedError<E>
    where E: error::Error,
{
    // Kept and forwarded for callers still using the deprecated API.
    #[allow(deprecated)]
    fn description(&self) -> &str {
        self.error.description()
    }

    #[allow(deprecated)]
    fn cause(&self) -> Option<&dyn error::Error> {
        self.error.cause()
    }

    // Without this override `source()` would use the trait default (`None`).
    // Callers that walk error chains via `source()` would then lose
    // everything below this error.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        self.error.source()
    }
}
322}