// futures_concurrency/stream/stream_group.rs

1use alloc::collections::BTreeSet;
2use core::fmt::{self, Debug};
3use core::ops::{Deref, DerefMut};
4use core::pin::Pin;
5use core::task::{Context, Poll};
6use futures_core::Stream;
7use slab::Slab;
8use smallvec::{smallvec, SmallVec};
9
10use crate::utils::{PollState, PollVec, WakerVec};
11
12/// A growable group of streams which act as a single unit.
13///
14/// # Example
15///
16/// **Basic example**
17///
18/// ```rust
19/// use futures_concurrency::stream::StreamGroup;
20/// use futures_lite::{stream, StreamExt};
21///
22/// # futures_lite::future::block_on(async {
23/// let mut group = StreamGroup::new();
24/// group.insert(stream::once(2));
25/// group.insert(stream::once(4));
26///
27/// let mut out = 0;
28/// while let Some(num) = group.next().await {
29///     out += num;
30/// }
31/// assert_eq!(out, 6);
32/// # });
33/// ```
34///
35/// **Update the group on every iteration**
36///
37/// ```rust
38/// use futures_concurrency::stream::StreamGroup;
39/// use lending_stream::prelude::*;
40/// use futures_lite::stream;
41///
42/// # futures_lite::future::block_on(async {
43/// let mut group = StreamGroup::new();
44/// group.insert(stream::once(4));
45///
46/// let mut index = 3;
47/// let mut out = 0;
48/// let mut group = group.lend_mut();
49/// while let Some((group, num)) = group.next().await {
50///     if index != 0 {
51///         group.insert(stream::once(index));
52///         index -= 1;
53///     }
54///     out += num;
55/// }
56/// assert_eq!(out, 10);
57/// # });
58/// ```
59#[must_use = "`StreamGroup` does nothing if not iterated over"]
60#[derive(Default)]
61#[pin_project::pin_project]
62pub struct StreamGroup<S> {
63    #[pin]
64    streams: Slab<S>,
65    wakers: WakerVec,
66    states: PollVec,
67    keys: BTreeSet<usize>,
68    key_removal_queue: SmallVec<[usize; 10]>,
69    capacity: usize,
70}
71
72impl<T: Debug> Debug for StreamGroup<T> {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        f.debug_struct("StreamGroup")
75            .field("slab", &"[..]")
76            .finish()
77    }
78}
79
80impl<S> StreamGroup<S> {
81    /// Create a new instance of `StreamGroup`.
82    ///
83    /// # Example
84    ///
85    /// ```rust
86    /// use futures_concurrency::stream::StreamGroup;
87    ///
88    /// let group = StreamGroup::new();
89    /// # let group: StreamGroup<usize> = group;
90    /// ```
91    pub fn new() -> Self {
92        Self::with_capacity(0)
93    }
94
95    /// Create a new instance of `StreamGroup` with a given capacity.
96    ///
97    /// # Example
98    ///
99    /// ```rust
100    /// use futures_concurrency::stream::StreamGroup;
101    ///
102    /// let group = StreamGroup::with_capacity(2);
103    /// # let group: StreamGroup<usize> = group;
104    /// ```
105    pub fn with_capacity(capacity: usize) -> Self {
106        Self {
107            streams: Slab::with_capacity(capacity),
108            wakers: WakerVec::new(capacity),
109            states: PollVec::new(capacity),
110            keys: BTreeSet::new(),
111            key_removal_queue: smallvec![],
112            capacity,
113        }
114    }
115
116    /// Return the number of futures currently active in the group.
117    ///
118    /// # Example
119    ///
120    /// ```rust
121    /// use futures_concurrency::stream::StreamGroup;
122    /// use futures_lite::stream;
123    ///
124    /// let mut group = StreamGroup::with_capacity(2);
125    /// assert_eq!(group.len(), 0);
126    /// group.insert(stream::once(12));
127    /// assert_eq!(group.len(), 1);
128    /// ```
129    #[inline(always)]
130    pub fn len(&self) -> usize {
131        self.streams.len()
132    }
133
134    /// Return the capacity of the `StreamGroup`.
135    ///
136    /// # Example
137    ///
138    /// ```rust
139    /// use futures_concurrency::stream::StreamGroup;
140    /// use futures_lite::stream;
141    ///
142    /// let group = StreamGroup::with_capacity(2);
143    /// assert_eq!(group.capacity(), 2);
144    /// # let group: StreamGroup<usize> = group;
145    /// ```
146    pub fn capacity(&self) -> usize {
147        self.capacity
148    }
149
150    /// Returns true if there are no futures currently active in the group.
151    ///
152    /// # Example
153    ///
154    /// ```rust
155    /// use futures_concurrency::stream::StreamGroup;
156    /// use futures_lite::stream;
157    ///
158    /// let mut group = StreamGroup::with_capacity(2);
159    /// assert!(group.is_empty());
160    /// group.insert(stream::once(12));
161    /// assert!(!group.is_empty());
162    /// ```
163    #[inline(always)]
164    pub fn is_empty(&self) -> bool {
165        self.streams.is_empty()
166    }
167
168    /// Removes a stream from the group. Returns whether the value was present in
169    /// the group.
170    ///
171    /// # Example
172    ///
173    /// ```
174    /// use futures_lite::stream;
175    /// use futures_concurrency::stream::StreamGroup;
176    ///
177    /// # futures_lite::future::block_on(async {
178    /// let mut group = StreamGroup::new();
179    /// let key = group.insert(stream::once(4));
180    /// assert_eq!(group.len(), 1);
181    /// group.remove(key);
182    /// assert_eq!(group.len(), 0);
183    /// # })
184    /// ```
185    pub fn remove(&mut self, key: Key) -> bool {
186        let is_present = self.keys.remove(&key.0);
187        if is_present {
188            self.states[key.0].set_none();
189            self.streams.remove(key.0);
190        }
191        is_present
192    }
193
194    /// Returns `true` if the `StreamGroup` contains a value for the specified key.
195    ///
196    /// # Example
197    ///
198    /// ```
199    /// use futures_lite::stream;
200    /// use futures_concurrency::stream::StreamGroup;
201    ///
202    /// # futures_lite::future::block_on(async {
203    /// let mut group = StreamGroup::new();
204    /// let key = group.insert(stream::once(4));
205    /// assert!(group.contains_key(key));
206    /// group.remove(key);
207    /// assert!(!group.contains_key(key));
208    /// # })
209    /// ```
210    pub fn contains_key(&mut self, key: Key) -> bool {
211        self.keys.contains(&key.0)
212    }
213
214    /// Reserves capacity for `additional` more streams to be inserted.
215    /// Does nothing if the capacity is already sufficient.
216    ///
217    /// # Example
218    ///
219    /// ```rust
220    /// use futures_concurrency::stream::StreamGroup;
221    /// use futures_lite::stream::Once;
222    /// # futures_lite::future::block_on(async {
223    /// let mut group: StreamGroup<Once<usize>> = StreamGroup::with_capacity(0);
224    /// assert_eq!(group.capacity(), 0);
225    /// group.reserve(10);
226    /// assert_eq!(group.capacity(), 10);
227    ///
228    /// // does nothing if capacity is sufficient
229    /// group.reserve(5);
230    /// assert_eq!(group.capacity(), 10);
231    /// # })
232    /// ```
233    pub fn reserve(&mut self, additional: usize) {
234        if self.len() + additional < self.capacity {
235            return;
236        }
237        let new_cap = self.capacity + additional;
238        self.wakers.resize(new_cap);
239        self.states.resize(new_cap);
240        self.streams.reserve_exact(additional);
241        self.capacity = new_cap;
242    }
243}
244
245impl<S: Stream> StreamGroup<S> {
246    /// Insert a new future into the group.
247    ///
248    /// # Example
249    ///
250    /// ```rust
251    /// use futures_concurrency::stream::StreamGroup;
252    /// use futures_lite::stream;
253    ///
254    /// let mut group = StreamGroup::with_capacity(2);
255    /// group.insert(stream::once(12));
256    /// ```
257    pub fn insert(&mut self, stream: S) -> Key
258    where
259        S: Stream,
260    {
261        if self.capacity <= self.len() {
262            self.reserve(self.capacity * 2 + 1);
263        }
264
265        let index = self.streams.insert(stream);
266        self.keys.insert(index);
267
268        // Set the corresponding state
269        self.states[index].set_pending();
270        self.wakers.readiness().set_ready(index);
271
272        Key(index)
273    }
274
275    /// Create a stream which also yields the key of each item.
276    ///
277    /// # Example
278    ///
279    /// ```rust
280    /// use futures_concurrency::stream::StreamGroup;
281    /// use futures_lite::{stream, StreamExt};
282    ///
283    /// # futures_lite::future::block_on(async {
284    /// let mut group = StreamGroup::new();
285    /// group.insert(stream::once(2));
286    /// group.insert(stream::once(4));
287    ///
288    /// let mut out = 0;
289    /// let mut group = group.keyed();
290    /// while let Some((_key, num)) = group.next().await {
291    ///     out += num;
292    /// }
293    /// assert_eq!(out, 6);
294    /// # });
295    /// ```
296    pub fn keyed(self) -> Keyed<S> {
297        Keyed { group: self }
298    }
299}
300
301impl<S: Stream> StreamGroup<S> {
302    fn poll_next_inner(
303        mut self: Pin<&mut Self>,
304        cx: &Context<'_>,
305    ) -> Poll<Option<(Key, <S as Stream>::Item)>> {
306        let mut this = self.as_mut().project();
307
308        // Short-circuit if we have no streams to iterate over
309        if this.streams.is_empty() {
310            return Poll::Ready(None);
311        }
312
313        // Set the top-level waker and check readiness
314        let mut readiness = this.wakers.readiness();
315        readiness.set_waker(cx.waker());
316        if !readiness.any_ready() {
317            // Nothing is ready yet
318            return Poll::Pending;
319        }
320
321        // Setup our stream state
322        let mut ret = Poll::Pending;
323        let mut done_count = 0;
324        let stream_count = this.streams.len();
325        let states = this.states;
326
327        // SAFETY: We unpin the stream set so we can later individually access
328        // single streams. Either to read from them or to drop them.
329        let streams = unsafe { this.streams.as_mut().get_unchecked_mut() };
330
331        for index in this.keys.iter().cloned() {
332            if states[index].is_pending() && readiness.clear_ready(index) {
333                // unlock readiness so we don't deadlock when polling
334                #[allow(clippy::drop_non_drop)]
335                drop(readiness);
336
337                // Obtain the intermediate waker.
338                let mut cx = Context::from_waker(this.wakers.get(index).unwrap());
339
340                // SAFETY: this stream here is a projection from the streams
341                // vec, which we're reading from.
342                let stream = unsafe { Pin::new_unchecked(&mut streams[index]) };
343                match stream.poll_next(&mut cx) {
344                    Poll::Ready(Some(item)) => {
345                        // Set the return type for the function
346                        ret = Poll::Ready(Some((Key(index), item)));
347
348                        // We just obtained an item from this index, make sure
349                        // we check it again on a next iteration
350                        states[index] = PollState::Pending;
351                        let mut readiness = this.wakers.readiness();
352                        readiness.set_ready(index);
353
354                        break;
355                    }
356                    Poll::Ready(None) => {
357                        // A stream has ended, make note of that
358                        done_count += 1;
359
360                        // Remove all associated data about the stream.
361                        // The only data we can't remove directly is the key entry.
362                        states[index] = PollState::None;
363                        streams.remove(index);
364                        this.key_removal_queue.push(index);
365                    }
366                    // Keep looping if there is nothing for us to do
367                    Poll::Pending => {}
368                };
369
370                // Lock readiness so we can use it again
371                readiness = this.wakers.readiness();
372            }
373        }
374
375        // Now that we're no longer borrowing `this.keys` we can loop over
376        // which items we need to remove
377        if !this.key_removal_queue.is_empty() {
378            for key in this.key_removal_queue.iter() {
379                this.keys.remove(key);
380            }
381            this.key_removal_queue.clear();
382        }
383
384        // If all streams turned up with `Poll::Ready(None)` our
385        // stream should return that
386        if done_count == stream_count {
387            ret = Poll::Ready(None);
388        }
389
390        ret
391    }
392}
393
394impl<S: Stream> Stream for StreamGroup<S> {
395    type Item = <S as Stream>::Item;
396
397    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398        match self.poll_next_inner(cx) {
399            Poll::Ready(Some((_key, item))) => Poll::Ready(Some(item)),
400            Poll::Ready(None) => Poll::Ready(None),
401            Poll::Pending => Poll::Pending,
402        }
403    }
404}
405
406impl<S: Stream> FromIterator<S> for StreamGroup<S> {
407    fn from_iter<T: IntoIterator<Item = S>>(iter: T) -> Self {
408        let iter = iter.into_iter();
409        let len = iter.size_hint().1.unwrap_or_default();
410        let mut this = Self::with_capacity(len);
411        for stream in iter {
412            this.insert(stream);
413        }
414        this
415    }
416}
417
/// Identifies a single stream inside a `StreamGroup`.
///
/// Returned by `StreamGroup::insert`; accepted by `StreamGroup::remove` and
/// `StreamGroup::contains_key`, and yielded alongside items by `Keyed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Key(usize);
421
422/// Iterate over items in the stream group with their associated keys.
423#[derive(Debug)]
424#[pin_project::pin_project]
425pub struct Keyed<S: Stream> {
426    #[pin]
427    group: StreamGroup<S>,
428}
429
430impl<S: Stream> Deref for Keyed<S> {
431    type Target = StreamGroup<S>;
432
433    fn deref(&self) -> &Self::Target {
434        &self.group
435    }
436}
437
438impl<S: Stream> DerefMut for Keyed<S> {
439    fn deref_mut(&mut self) -> &mut Self::Target {
440        &mut self.group
441    }
442}
443
444impl<S: Stream> Stream for Keyed<S> {
445    type Item = (Key, <S as Stream>::Item);
446
447    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
448        let mut this = self.project();
449        this.group.as_mut().poll_next_inner(cx)
450    }
451}
452
#[cfg(test)]
mod test {
    use super::StreamGroup;
    use futures_lite::{prelude::*, stream};

    /// Two single-item streams are drained completely, and the group ends up empty.
    #[test]
    fn smoke() {
        futures_lite::future::block_on(async {
            let mut group = StreamGroup::new();
            for value in [2, 4] {
                group.insert(stream::once(value));
            }

            let mut sum = 0;
            while let Some(value) = group.next().await {
                sum += value;
            }

            assert_eq!(sum, 6);
            assert!(group.is_empty());
            assert_eq!(group.len(), 0);
        });
    }

    /// Inserting into a zero-capacity group must allocate more slots.
    #[test]
    fn capacity_grow_on_insert() {
        futures_lite::future::block_on(async {
            let mut group = StreamGroup::new();
            let before = group.capacity();

            group.insert(stream::once(1));

            assert!(before < group.capacity());
        });
    }
}