par_stream/
shared_stream.rs

1use crate::common::*;
2use crossbeam::queue::SegQueue;
3use dashmap::DashMap;
4use futures::task::{waker_ref, ArcWake};
5use std::sync::Weak;
6
7// constants
8
/// No task is currently polling the wrapped stream.
const IDLE: usize = 0;
/// Exactly one task holds the polling "lock" on the wrapped stream.
const POLLING: usize = 1;
/// The wrapped stream has returned end of stream.
const COMPLETE: usize = 2;
/// The wrapped stream panicked while it was being polled.
const POISONED: usize = 3;

/// Sentinel meaning "no waker key has been assigned to this handle yet".
const NULL_WAKER_KEY: usize = usize::MAX;
15
16/// Stream for the [`shared`](super::StreamExt::shared) method.
17///
18/// The stream is cloneable. Polling the stream will poll the internal
19/// stream shared with the other owners. If there are multiple consumers
20/// for the shared stream, the items are sent in first-come-first-serve manner.
21#[must_use = "streams do nothing unless you consume or poll them"]
22pub struct Shared<St>
23where
24    St: ?Sized + Stream,
25{
26    inner: Option<Arc<Inner<St>>>,
27    waker_key: usize,
28}
29
30struct Inner<St>
31where
32    St: ?Sized + Stream,
33{
34    state: AtomicUsize,
35    notifier: Arc<Notifier>,
36    stream: UnsafeCell<St>,
37}
38
/// Waker bookkeeping shared by all handles of one stream.
struct Notifier {
    /// The number of times this notifier has been woken through its
    /// `ArcWake` implementation.
    wake_count: AtomicUsize,
    /// Keys of the tasks that are waiting to be woken by the next `notify`.
    pending_waker_keys: SegQueue<usize>,
    /// The waker registered for each waker key.
    wakers: DashMap<usize, Waker>,
}
47
/// A weak reference to a [`Shared`] that can be upgraded much like an `Arc`.
///
/// It wraps a `Weak` pointer to the shared state, so holding one does not keep
/// that state alive.
pub struct WeakShared<St: Stream>(Weak<Inner<St>>);
50
51impl<St: Stream> Clone for WeakShared<St> {
52    fn clone(&self) -> Self {
53        Self(self.0.clone())
54    }
55}
56
// The wrapped stream lives inside the `Arc`-allocated `Inner` and is polled
// through that allocation, so it is not moved when a `Shared` handle is moved.
impl<St: Stream> Unpin for Shared<St> {}
60
61impl<St: Stream> fmt::Debug for Shared<St> {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        f.debug_struct("Shared")
64            .field("inner", &self.inner)
65            .field("waker_key", &self.waker_key)
66            .finish()
67    }
68}
69
70impl<St: Stream> fmt::Debug for Inner<St> {
71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72        f.debug_struct("Inner").finish()
73    }
74}
75
76impl<St: Stream> fmt::Debug for WeakShared<St> {
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        f.debug_struct("WeakShared").finish()
79    }
80}
81
// SAFETY (review note): the field that blocks the auto trait is the
// `UnsafeCell<St>`. In the code visible in this file it is dereferenced only
// inside `poll_next`, after the `IDLE -> POLLING` compare-exchange on `state`
// has succeeded, so the stream is touched by one task at a time. Moving the
// whole `Inner` to another thread therefore needs `St: Send`; items are handed
// to the polling task, hence `St::Item: Send`.
// TODO confirm that no other code path accesses `stream`.
unsafe impl<St> Send for Inner<St>
where
    St: Stream + Send,
    St::Item: Send,
{
}
88
// SAFETY (review note): sharing `&Inner` across threads lets any of them poll
// the stream, but only the task that wins the `IDLE -> POLLING`
// compare-exchange on `state` in `poll_next` dereferences the `UnsafeCell`.
// Access is therefore exclusive (mutex-like), which is presumably why
// `St: Send` rather than `St: Sync` is required here.
// TODO confirm that no other code path accesses `stream`.
unsafe impl<St> Sync for Inner<St>
where
    St: Stream + Send,
    St::Item: Send,
{
}
95
96impl<St: Stream> Shared<St> {
97    pub fn new(stream: St) -> Self {
98        let inner = Inner {
99            stream: UnsafeCell::new(stream),
100            state: AtomicUsize::new(IDLE),
101            notifier: Arc::new(Notifier {
102                wake_count: AtomicUsize::new(0),
103                wakers: DashMap::new(),
104                pending_waker_keys: SegQueue::new(),
105            }),
106        };
107
108        Self {
109            inner: Some(Arc::new(inner)),
110            waker_key: NULL_WAKER_KEY,
111        }
112    }
113}
114
115impl<St> Shared<St>
116where
117    St: Stream,
118{
119    /// Creates a new [`WeakShared`] for this [`Shared`].
120    ///
121    /// Returns [`None`] if it has already been polled to completion.
122    pub fn downgrade(&self) -> Option<WeakShared<St>> {
123        if let Some(inner) = self.inner.as_ref() {
124            return Some(WeakShared(Arc::downgrade(inner)));
125        }
126        None
127    }
128
129    /// Gets the number of strong pointers to this allocation.
130    ///
131    /// Returns [`None`] if it has already been polled to completion.
132    ///
133    /// # Safety
134    ///
135    /// This method by itself is safe, but using it correctly requires extra care. Another thread
136    /// can change the strong count at any time, including potentially between calling this method
137    /// and acting on the result.
138    pub fn strong_count(&self) -> Option<usize> {
139        self.inner.as_ref().map(Arc::strong_count)
140    }
141
142    /// Gets the number of weak pointers to this allocation.
143    ///
144    /// Returns [`None`] if it has already been polled to completion.
145    ///
146    /// # Safety
147    ///
148    /// This method by itself is safe, but using it correctly requires extra care. Another thread
149    /// can change the weak count at any time, including potentially between calling this method
150    /// and acting on the result.
151    pub fn weak_count(&self) -> Option<usize> {
152        self.inner.as_ref().map(Arc::weak_count)
153    }
154}
155
156impl<St> Inner<St>
157where
158    St: Stream,
159{
160    /// Registers the current task to receive a wakeup when we are awoken.
161    fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) {
162        let notifier = &self.notifier;
163        let new_waker = cx.waker();
164
165        if *waker_key == NULL_WAKER_KEY {
166            *waker_key = next_waker_key();
167            notifier.wakers.insert(*waker_key, new_waker.clone());
168        } else {
169            use dashmap::mapref::entry::Entry as E;
170
171            match notifier.wakers.entry(*waker_key) {
172                E::Occupied(entry) => {
173                    let mut old_waker = entry.into_ref();
174
175                    if !new_waker.will_wake(&*old_waker) {
176                        *old_waker = new_waker.clone();
177                    }
178                }
179                E::Vacant(entry) => {
180                    entry.insert(new_waker.clone());
181                }
182            }
183        }
184        debug_assert!(*waker_key != NULL_WAKER_KEY);
185    }
186}
187
188impl<St> FusedStream for Shared<St>
189where
190    St: Stream,
191{
192    fn is_terminated(&self) -> bool {
193        self.inner.is_none()
194    }
195}
196
197impl<St> Stream for Shared<St>
198where
199    St: Stream,
200{
201    type Item = St::Item;
202
203    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
204        let this = &mut *self;
205
206        // Return end of stream if polled again after completion
207        let inner = match this.inner.take() {
208            Some(inner) => inner,
209            None => {
210                return Ready(None);
211            }
212        };
213
214        // Fast path for when the wrapped stream has already completed
215        if inner.state.load(Acquire) == COMPLETE {
216            return Ready(None);
217        }
218
219        // Make sure a waker key is registered for this waker.
220        inner.record_waker(&mut this.waker_key, cx);
221
222        // Transfer state: IDLE -> POLLING
223        match inner
224            .state
225            .compare_exchange(IDLE, POLLING, SeqCst, SeqCst)
226            .unwrap_or_else(|x| x)
227        {
228            IDLE => {
229                // Lock acquired, fall through
230            }
231            POLLING => {
232                // Another task is currently polling, at this point we just want
233                // to ensure that the waker for this task is registered
234                inner.notifier.register_pending(this.waker_key);
235                this.inner = Some(inner);
236                return Pending;
237            }
238            COMPLETE => {
239                return Ready(None);
240            }
241            POISONED => panic!("inner stream panicked during poll"),
242            _ => unreachable!(),
243        }
244
245        /* start of critical section (to the end of function) */
246
247        // the guard marks poisoned state when dropping if panic happened
248        let _reset = Reset(&inner.state);
249
250        // create context for underlying stream
251        let waker = waker_ref(&inner.notifier);
252        let mut stream_cx = Context::from_waker(&waker);
253
254        // get stream reference
255        let stream = unsafe {
256            let stream = &mut *inner.stream.get();
257            Pin::new_unchecked(stream)
258        };
259
260        // remember the wake count before polling
261        let wake_count = inner.notifier.wake_count();
262
263        match stream.poll_next(&mut stream_cx) {
264            Pending => {
265                // Transfer state: POLLING -> IDLE
266                inner.state.store(IDLE, SeqCst);
267
268                // Register the waker key to pending list.
269                let should_wake = inner
270                    .notifier
271                    .wake_or_register_pending(this.waker_key, wake_count);
272
273                // If the wake_count changed, indicating the stream wakes earlier, wake itself.
274                if should_wake {
275                    cx.waker().wake_by_ref();
276                }
277
278                drop(_reset);
279                this.inner = Some(inner);
280                Pending
281            }
282            Ready(Some(item)) => {
283                // Transfer state: POLLING -> IDLE
284                inner.state.store(IDLE, SeqCst);
285
286                // Wake pending tasks
287                inner.notifier.notify();
288
289                drop(_reset); // Make borrow checker happy
290                this.inner = Some(inner);
291                Ready(Some(item))
292            }
293            Ready(None) => {
294                // Transfer state: POLLING -> COMPLETE
295                inner.state.store(COMPLETE, SeqCst);
296
297                // Wake all tasks
298                inner.notifier.close(this.waker_key);
299                drop(_reset); // Make borrow checker happy
300
301                Ready(None)
302            }
303        }
304    }
305}
306
307impl<St> Clone for Shared<St>
308where
309    St: Stream,
310{
311    fn clone(&self) -> Self {
312        Self {
313            inner: self.inner.clone(),
314            waker_key: NULL_WAKER_KEY,
315        }
316    }
317}
318
319impl<St> Drop for Shared<St>
320where
321    St: ?Sized + Stream,
322{
323    fn drop(&mut self) {
324        if self.waker_key != NULL_WAKER_KEY {
325            if let Some(ref inner) = self.inner {
326                inner.notifier.wakers.remove(&self.waker_key);
327            }
328        }
329    }
330}
331
332impl ArcWake for Notifier {
333    fn wake_by_ref(this: &Arc<Self>) {
334        this.wake_count.fetch_add(1, SeqCst);
335        this.notify();
336    }
337}
338
339impl Notifier {
340    fn wake_count(&self) -> usize {
341        self.wake_count.load(Acquire)
342    }
343
344    /// Register the waker_key to pending list.
345    fn register_pending(&self, waker_key: usize) {
346        self.pending_waker_keys.push(waker_key);
347    }
348
349    /// Wake or register the waker_key to pending list according to expected wake count.
350    ///
351    /// The methods returns whether to wake or not.
352    fn wake_or_register_pending(&self, waker_key: usize, expected_wake_count: usize) -> bool {
353        debug_assert!(waker_key != NULL_WAKER_KEY);
354        self.pending_waker_keys.push(waker_key);
355        self.wake_count
356            .compare_exchange(expected_wake_count, expected_wake_count, SeqCst, SeqCst)
357            .is_err()
358    }
359
360    fn notify(&self) {
361        while let Some(waker_key) = self.pending_waker_keys.pop() {
362            if let Some(waker) = self.wakers.get(&waker_key) {
363                waker.wake_by_ref();
364            }
365        }
366    }
367
368    fn close(&self, waker_key: usize) {
369        debug_assert!(waker_key != NULL_WAKER_KEY);
370
371        self.wakers.retain(|&key, waker| {
372            if key != waker_key {
373                waker.wake_by_ref();
374            }
375            false
376        });
377    }
378}
379
380impl<St: Stream> WeakShared<St> {
381    /// Attempts to upgrade this [`WeakShared`] into a [`Shared`].
382    ///
383    /// Returns [`None`] if all clones of the [`Shared`] have been dropped or polled
384    /// to completion.
385    pub fn upgrade(&self) -> Option<Shared<St>> {
386        Some(Shared {
387            inner: Some(self.0.upgrade()?),
388            waker_key: NULL_WAKER_KEY,
389        })
390    }
391}
392
/// Drop guard over a state atomic; its `Drop` impl stores `POISONED` into the
/// atomic when it is dropped while the thread is panicking.
struct Reset<'a>(&'a AtomicUsize);
394
395impl Drop for Reset<'_> {
396    fn drop(&mut self) {
397        use std::thread;
398
399        if thread::panicking() {
400            self.0.store(POISONED, SeqCst);
401        }
402    }
403}
404
/// Hands out process-wide waker keys: each call returns the previous counter
/// value (starting at 0) and increments the counter by one.
fn next_waker_key() -> usize {
    static NEXT_KEY: AtomicUsize = AtomicUsize::new(0);

    let allocated = NEXT_KEY.fetch_add(1, SeqCst);
    allocated
}