use crate::common::*;
use crossbeam::queue::SegQueue;
use dashmap::DashMap;
use futures::task::{waker_ref, ArcWake};
use std::sync::Weak;

// States for `Inner::state`: pollers race to move IDLE -> POLLING; the winner
// drives the inner stream and then restores IDLE, moves to COMPLETE when the
// stream finishes, or to POISONED if the inner poll panics.
const IDLE: usize = 0;
const POLLING: usize = 1;
const COMPLETE: usize = 2;
const POISONED: usize = 3;

// Sentinel meaning "this handle has not registered a waker yet".
const NULL_WAKER_KEY: usize = usize::max_value();

#[must_use = "streams do nothing unless you consume or poll them"]
/// A cloneable handle to a stream that is shared by several consumers.
/// Every item of the wrapped stream is yielded by exactly one of the handles.
pub struct Shared<St>
where
    St: ?Sized + Stream,
{
    inner: Option<Arc<Inner<St>>>,
    waker_key: usize,
}

struct Inner<St>
where
    St: ?Sized + Stream,
{
    /// One of IDLE, POLLING, COMPLETE, or POISONED.
    state: AtomicUsize,
    notifier: Arc<Notifier>,
    /// Only accessed by the handle that currently holds the POLLING state.
    stream: UnsafeCell<St>,
}

struct Notifier {
    /// Incremented every time the inner stream wakes this notifier.
    wake_count: AtomicUsize,
    /// Keys of handles waiting to be woken at the next notification.
    pending_waker_keys: SegQueue<usize>,
    /// Wakers of all live handles, keyed by their `waker_key`.
    wakers: DashMap<usize, Waker>,
}

pub struct WeakShared<St: Stream>(Weak<Inner<St>>);

impl<St: Stream> Clone for WeakShared<St> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl<St: Stream> Unpin for Shared<St> {}

impl<St: Stream> fmt::Debug for Shared<St> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Shared")
            .field("inner", &self.inner)
            .field("waker_key", &self.waker_key)
            .finish()
    }
}

impl<St: Stream> fmt::Debug for Inner<St> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Inner").finish()
    }
}

impl<St: Stream> fmt::Debug for WeakShared<St> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("WeakShared").finish()
    }
}

// SAFETY: the `UnsafeCell<St>` is only accessed by the single handle that wins
// the IDLE -> POLLING transition, so the inner stream is never touched from two
// threads at once.
unsafe impl<St> Send for Inner<St>
where
    St: Stream + Send,
    St::Item: Send,
{
}

unsafe impl<St> Sync for Inner<St>
where
    St: Stream + Send,
    St::Item: Send,
{
}

impl<St: Stream> Shared<St> {
    pub fn new(stream: St) -> Self {
        let inner = Inner {
            stream: UnsafeCell::new(stream),
            state: AtomicUsize::new(IDLE),
            notifier: Arc::new(Notifier {
                wake_count: AtomicUsize::new(0),
                wakers: DashMap::new(),
                pending_waker_keys: SegQueue::new(),
            }),
        };

        Self {
            inner: Some(Arc::new(inner)),
            waker_key: NULL_WAKER_KEY,
        }
    }
}

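// `new` wraps the stream in a reference-counted `Inner`; additional handles are
// created with `clone` and then compete for items, so each item of the inner
// stream is yielded by exactly one handle. See the sketch tests at the end of
// this file for a hedged usage example.
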
impl<St> Shared<St>
where
    St: Stream,
{
    /// Creates a non-owning [`WeakShared`] handle, or `None` if this handle
    /// has already terminated.
    pub fn downgrade(&self) -> Option<WeakShared<St>> {
        if let Some(inner) = self.inner.as_ref() {
            return Some(WeakShared(Arc::downgrade(inner)));
        }
        None
    }

    /// Number of strong handles to the shared stream, or `None` if this handle
    /// has already terminated.
    pub fn strong_count(&self) -> Option<usize> {
        self.inner.as_ref().map(Arc::strong_count)
    }

    /// Number of weak handles to the shared stream, or `None` if this handle
    /// has already terminated.
    pub fn weak_count(&self) -> Option<usize> {
        self.inner.as_ref().map(Arc::weak_count)
    }
}

impl<St> Inner<St>
where
    St: Stream,
{
    fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) {
        let notifier = &self.notifier;
        let new_waker = cx.waker();

        if *waker_key == NULL_WAKER_KEY {
            // First poll from this handle: allocate a key and store the waker.
            *waker_key = next_waker_key();
            notifier.wakers.insert(*waker_key, new_waker.clone());
        } else {
            use dashmap::mapref::entry::Entry as E;

            match notifier.wakers.entry(*waker_key) {
                E::Occupied(entry) => {
                    let mut old_waker = entry.into_ref();

                    // Only replace the stored waker if it would wake a different task.
                    if !new_waker.will_wake(&*old_waker) {
                        *old_waker = new_waker.clone();
                    }
                }
                E::Vacant(entry) => {
                    entry.insert(new_waker.clone());
                }
            }
        }
        debug_assert!(*waker_key != NULL_WAKER_KEY);
    }
}

impl<St> FusedStream for Shared<St>
where
    St: Stream,
{
    fn is_terminated(&self) -> bool {
        self.inner.is_none()
    }
}

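// Polling protocol: each `poll_next` races to flip `state` from IDLE to
// POLLING. The winner polls the inner stream with the shared `Notifier` as its
// waker; losers enqueue their waker key and return `Pending`. When the winner
// obtains an item it restores IDLE and wakes the pending handles so one of them
// can poll next; when the inner stream ends it moves the state to COMPLETE and
// wakes everyone for the final `None`.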
impl<St> Stream for Shared<St>
where
    St: Stream,
{
    type Item = St::Item;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;

        // A `None` inner means this handle already observed the end of the stream.
        let inner = match this.inner.take() {
            Some(inner) => inner,
            None => {
                return Ready(None);
            }
        };

        if inner.state.load(Acquire) == COMPLETE {
            return Ready(None);
        }

        // Register (or refresh) this handle's waker before racing for the lock.
        inner.record_waker(&mut this.waker_key, cx);

        match inner
            .state
            .compare_exchange(IDLE, POLLING, SeqCst, SeqCst)
            .unwrap_or_else(|x| x)
        {
            IDLE => {
                // We won the race; fall through and poll the inner stream.
            }
            POLLING => {
                // Another handle is polling; wait until it notifies us.
                inner.notifier.register_pending(this.waker_key);
                this.inner = Some(inner);
                return Pending;
            }
            COMPLETE => {
                return Ready(None);
            }
            POISONED => panic!("inner stream panicked during poll"),
            _ => unreachable!(),
        }

        // Poisons the state if the poll below panics.
        let _reset = Reset(&inner.state);

        let waker = waker_ref(&inner.notifier);
        let mut stream_cx = Context::from_waker(&waker);

        let stream = unsafe {
            // Exclusive access is guaranteed by holding the POLLING state.
            let stream = &mut *inner.stream.get();
            Pin::new_unchecked(stream)
        };

        // Snapshot the wake count so wake-ups racing with this poll are detected.
        let wake_count = inner.notifier.wake_count();

        match stream.poll_next(&mut stream_cx) {
            Pending => {
                inner.state.store(IDLE, SeqCst);

                let should_wake = inner
                    .notifier
                    .wake_or_register_pending(this.waker_key, wake_count);

                if should_wake {
                    // A wake-up arrived while we were polling; schedule another poll.
                    cx.waker().wake_by_ref();
                }

                drop(_reset);
                this.inner = Some(inner);
                Pending
            }
            Ready(Some(item)) => {
                inner.state.store(IDLE, SeqCst);

                // Hand the lock over: wake the pending handles so one of them polls next.
                inner.notifier.notify();

                drop(_reset);
                this.inner = Some(inner);
                Ready(Some(item))
            }
            Ready(None) => {
                inner.state.store(COMPLETE, SeqCst);

                // Wake all other handles so they observe the completed state.
                inner.notifier.close(this.waker_key);
                drop(_reset);
                Ready(None)
            }
        }
    }
}

impl<St> Clone for Shared<St>
where
    St: Stream,
{
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            waker_key: NULL_WAKER_KEY,
        }
    }
}

impl<St> Drop for Shared<St>
where
    St: ?Sized + Stream,
{
    fn drop(&mut self) {
        // Deregister this handle's waker so the notifier never wakes a dead handle.
        if self.waker_key != NULL_WAKER_KEY {
            if let Some(ref inner) = self.inner {
                inner.notifier.wakers.remove(&self.waker_key);
            }
        }
    }
}

impl ArcWake for Notifier {
    fn wake_by_ref(this: &Arc<Self>) {
        this.wake_count.fetch_add(1, SeqCst);
        this.notify();
    }
}

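// `Notifier` is both the waker handed to the inner stream (via `ArcWake`) and
// the registry of consumer wakers. `wake_count` guards against a lost wake-up:
// a poller snapshots it before polling, and `wake_or_register_pending` reports
// whether the count changed while the poll was in flight, in which case the
// poller re-wakes itself instead of sleeping.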
impl Notifier {
    fn wake_count(&self) -> usize {
        self.wake_count.load(Acquire)
    }

    fn register_pending(&self, waker_key: usize) {
        self.pending_waker_keys.push(waker_key);
    }

    /// Registers the key as pending and returns `true` if a wake-up arrived
    /// since `expected_wake_count` was sampled, i.e. the caller should wake
    /// itself rather than wait for a notification.
    fn wake_or_register_pending(&self, waker_key: usize, expected_wake_count: usize) -> bool {
        debug_assert!(waker_key != NULL_WAKER_KEY);
        self.pending_waker_keys.push(waker_key);
        self.wake_count
            .compare_exchange(expected_wake_count, expected_wake_count, SeqCst, SeqCst)
            .is_err()
    }

    /// Wakes every handle that registered itself as pending.
    fn notify(&self) {
        while let Some(waker_key) = self.pending_waker_keys.pop() {
            if let Some(waker) = self.wakers.get(&waker_key) {
                waker.wake_by_ref();
            }
        }
    }

    /// Wakes all handles except the closer and clears the waker registry.
    fn close(&self, waker_key: usize) {
        debug_assert!(waker_key != NULL_WAKER_KEY);

        self.wakers.retain(|&key, waker| {
            if key != waker_key {
                waker.wake_by_ref();
            }
            false
        });
    }
}

impl<St: Stream> WeakShared<St> {
    /// Attempts to create a new [`Shared`] handle, returning `None` if all
    /// strong handles have been dropped.
    pub fn upgrade(&self) -> Option<Shared<St>> {
        Some(Shared {
            inner: Some(self.0.upgrade()?),
            waker_key: NULL_WAKER_KEY,
        })
    }
}

/// Guard that poisons the shared state if the inner stream panics mid-poll.
struct Reset<'a>(&'a AtomicUsize);

impl Drop for Reset<'_> {
    fn drop(&mut self) {
        use std::thread;

        if thread::panicking() {
            self.0.store(POISONED, SeqCst);
        }
    }
}

fn next_waker_key() -> usize {
    static KEY: AtomicUsize = AtomicUsize::new(0);
    KEY.fetch_add(1, SeqCst)
}
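
// The tests below are a minimal, illustrative sketch rather than part of the
// original module. They assume the crate's existing `futures` dependency also
// exposes `futures::stream::iter` and `futures::task::noop_waker` (both come
// from the same facade this file already imports `waker_ref` from).
#[cfg(test)]
mod shared_stream_sketch {
    use super::{Shared, WeakShared};
    use futures::task::noop_waker;
    use futures::Stream;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    // Two handles polled on one thread: every item of the inner stream shows up
    // on exactly one handle, since each successful poll drains the single shared stream.
    #[test]
    fn items_are_distributed_across_clones() {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let mut a = Shared::new(futures::stream::iter(0..4));
        let mut b = a.clone();

        let mut seen = Vec::new();
        loop {
            // `stream::iter` never returns `Pending`, so this loop terminates.
            match Pin::new(&mut a).poll_next(&mut cx) {
                Poll::Ready(Some(item)) => seen.push(item),
                Poll::Ready(None) => break,
                Poll::Pending => {}
            }
            match Pin::new(&mut b).poll_next(&mut cx) {
                Poll::Ready(Some(item)) => seen.push(item),
                Poll::Ready(None) => break,
                Poll::Pending => {}
            }
        }

        seen.sort_unstable();
        assert_eq!(seen, vec![0, 1, 2, 3]);
    }

    // A weak handle upgrades only while at least one strong handle is alive.
    #[test]
    fn weak_handle_dies_with_last_strong_handle() {
        let shared = Shared::new(futures::stream::iter(0..1));
        let weak: WeakShared<_> = shared.downgrade().expect("live handle can downgrade");

        assert!(weak.upgrade().is_some());
        drop(shared);
        assert!(weak.upgrade().is_none());
    }
}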