pollable_map/
stream.rs

1pub mod optional;
2pub mod set;
3pub mod timeout_map;
4pub mod timeout_set;
5
6use crate::common::InnerMap;
7use futures::stream::{FusedStream, SelectAll};
8use futures::{Stream, StreamExt};
9use std::pin::Pin;
10use std::task::{Context, Poll, Waker};
11
12/// Combining multiple streams into one, with each stream having a unique key.
13pub struct StreamMap<K, S> {
14    list: SelectAll<InnerMap<K, S>>,
15    empty: bool,
16    waker: Option<Waker>,
17}
18
19impl<K, T> Default for StreamMap<K, T>
20where
21    K: Clone + Unpin,
22    T: Stream + Send + Unpin + 'static,
23{
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl<K, T> StreamMap<K, T>
30where
31    K: Clone + Unpin,
32    T: Stream + Send + Unpin + 'static,
33{
34    /// Creates an empty [`StreamMap`]
35    pub fn new() -> Self {
36        Self {
37            list: SelectAll::new(),
38            empty: true,
39            waker: None,
40        }
41    }
42}
43
44impl<K, T> StreamMap<K, T>
45where
46    K: Clone + PartialEq + Send + Unpin + 'static,
47    T: Stream + Send + Unpin + 'static,
48{
49    /// Insert a stream into the map with a unique key.
50    /// The function will return true if the map does not have the key present,
51    /// otherwise it will return false
52    pub fn insert(&mut self, key: K, stream: T) -> bool {
53        if self.contains_key(&key) {
54            return false;
55        }
56
57        let st = InnerMap::new(key, stream);
58        self.list.push(st);
59
60        if let Some(waker) = self.waker.take() {
61            waker.wake();
62        }
63
64        self.empty = false;
65        true
66    }
67
68    /// Mark stream with assigned key to wake up on successful yield.
69    /// Will return false if stream does not exist or if value is the same as
70    /// previously set.
71    pub fn set_wake_on_success(&mut self, key: &K, wake_on_success: bool) -> bool {
72        self.list
73            .iter_mut()
74            .find(|st| st.key().eq(key))
75            .is_some_and(|st| st.set_wake_on_success(wake_on_success))
76    }
77
78    /// An iterator visiting all key-value pairs in arbitrary order.
79    pub fn iter(&self) -> impl Iterator<Item = (&K, &T)> {
80        self.list.iter().filter_map(|st| st.key_value())
81    }
82
83    /// An iterator visiting all key-value pairs mutably in arbitrary order.
84    pub fn iter_mut(&mut self) -> impl Iterator<Item = (&K, &mut T)> {
85        self.list.iter_mut().filter_map(|st| st.key_value_mut())
86    }
87
88    /// An iterator visiting all key-value pairs with a pinned valued in arbitrary order
89    pub fn iter_pin(&mut self) -> impl Iterator<Item = (&K, Pin<&mut T>)> {
90        self.list.iter_mut().filter_map(|st| st.key_value_pin())
91    }
92
93    /// Returns an iterator visiting all keys in arbitrary order.
94    pub fn keys(&self) -> impl Iterator<Item = &K> {
95        self.list.iter().map(|st| st.key())
96    }
97
98    /// An iterator visiting all values in arbitrary order.
99    pub fn values(&self) -> impl Iterator<Item = &T> {
100        self.list.iter().filter_map(|st| st.inner())
101    }
102
103    /// An iterator visiting all values mutably in arbitrary order.
104    pub fn values_mut(&mut self) -> impl Iterator<Item = &mut T> {
105        self.list.iter_mut().filter_map(|st| st.inner_mut())
106    }
107
108    /// Returns `true` if the map contains a stream for the specified key.
109    pub fn contains_key(&self, key: &K) -> bool {
110        self.list.iter().any(|st| st.key().eq(key))
111    }
112
113    /// Clears the map.
114    pub fn clear(&mut self) {
115        self.list.clear();
116    }
117
118    /// Returns a reference to the stream corresponding to the key.
119    pub fn get(&self, key: &K) -> Option<&T> {
120        self.list
121            .iter()
122            .find(|st| st.key().eq(key))
123            .and_then(|st| st.inner())
124    }
125
126    /// Returns a mutable stream to the value corresponding to the key.
127    pub fn get_mut(&mut self, key: &K) -> Option<&mut T> {
128        self.list
129            .iter_mut()
130            .find(|st| st.key().eq(key))
131            .and_then(|st| st.inner_mut())
132    }
133
134    /// Returns a muable stream or default value if it does not exist.
135    pub fn get_mut_or_default(&mut self, key: &K) -> &mut T
136    where
137        T: Default,
138    {
139        self.insert(key.clone(), T::default());
140        self.get_mut(key).expect("valid entry")
141    }
142
143    /// Returns a pinned stream corresponding to the key.
144    pub fn get_pinned(&mut self, key: &K) -> Option<Pin<&mut T>> {
145        self.list
146            .iter_mut()
147            .find(|st| st.key().eq(key))
148            .and_then(|st| st.inner_pin())
149    }
150
151    /// Removes a key from the map, returning the stream.
152    pub fn remove(&mut self, key: &K) -> Option<T> {
153        self.list
154            .iter_mut()
155            .find(|st| st.key().eq(key))
156            .and_then(|st| st.take_inner())
157    }
158
159    /// Returns the number of streams in the map.
160    pub fn len(&self) -> usize {
161        self.list.iter().filter(|st| st.inner().is_some()).count()
162    }
163
164    /// Return `true` map contains no elements.
165    pub fn is_empty(&self) -> bool {
166        self.list.is_empty() || self.list.iter().all(|st| st.inner().is_none())
167    }
168}
169
170impl<K, T> FromIterator<(K, T)> for StreamMap<K, T>
171where
172    K: Clone + PartialEq + Send + Unpin + 'static,
173    T: Stream + Send + Unpin + 'static,
174{
175    fn from_iter<I: IntoIterator<Item = (K, T)>>(iter: I) -> Self {
176        let mut maps = Self::new();
177        for (key, val) in iter {
178            maps.insert(key, val);
179        }
180        maps
181    }
182}
183
184impl<K, T> Stream for StreamMap<K, T>
185where
186    K: Clone + PartialEq + Send + Unpin + 'static,
187    T: Stream + Unpin + Send + 'static,
188{
189    type Item = (K, T::Item);
190
191    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192        let this = &mut *self;
193
194        if this.list.is_empty() {
195            self.waker = Some(cx.waker().clone());
196            return Poll::Pending;
197        }
198
199        loop {
200            match this.list.poll_next_unpin(cx) {
201                Poll::Ready(Some((key, Some(item)))) => return Poll::Ready(Some((key, item))),
202                // We continue in case there is any progress on the set of streams
203                Poll::Ready(Some((key, None))) => {
204                    this.remove(&key);
205                }
206                Poll::Ready(None) => {
207                    // While we could allow the stream to continue to be pending, it would make more sense to notify that the stream
208                    // is empty without needing to explicitly check while polling the actual "map" itself
209                    // So we would mark a field to notify that the state is finished and return `Poll::Ready(None)` so the stream
210                    // can be terminated while on the next poll, we could let it be return pending.
211                    // We do this so that we are not returning `Poll::Ready(None)` each time the map is polled
212                    // as that may be seen as UB and may cause an increase in cpu usage
213                    if self.empty {
214                        self.waker = Some(cx.waker().clone());
215                        return Poll::Pending;
216                    }
217
218                    self.empty = true;
219                    return Poll::Ready(None);
220                }
221                Poll::Pending => {
222                    // Returning `None` does not mean the stream is actually terminated
223                    self.waker = Some(cx.waker().clone());
224                    return Poll::Pending;
225                }
226            }
227        }
228    }
229
230    fn size_hint(&self) -> (usize, Option<usize>) {
231        self.list.size_hint()
232    }
233}
234
235impl<K, T> FusedStream for StreamMap<K, T>
236where
237    K: Clone + PartialEq + Send + Unpin + 'static,
238    T: Stream + Unpin + Send + 'static,
239{
240    fn is_terminated(&self) -> bool {
241        self.list.is_terminated()
242    }
243}
244
#[cfg(test)]
mod test {
    use crate::stream::StreamMap;
    use futures::stream::empty;
    use futures::{Stream, StreamExt};
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Test stream that carries a value but never yields: every poll returns `Pending`.
    struct Idle<T> {
        value: T,
    }

    impl<T> Idle<T> {
        fn new(value: T) -> Self {
            Idle { value }
        }

        fn get(&self) -> &T {
            &self.value
        }

        fn set(&mut self, value: T) {
            self.value = value;
        }
    }

    impl<T: Unpin> Stream for Idle<T> {
        type Item = T;

        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<T>> {
            Poll::Pending
        }
    }

    #[test]
    fn existing_key() {
        let mut map = StreamMap::new();
        let first = map.insert(1, empty::<()>());
        let duplicate = map.insert(1, empty::<()>());
        assert!(first);
        assert!(!duplicate);
    }

    #[test]
    fn poll_multiple_keyed_streams() {
        let mut map = StreamMap::new();
        map.insert(1, futures::stream::once(async { 10 }).boxed());
        map.insert(2, futures::stream::once(async { 20 }).boxed());
        map.insert(3, futures::stream::iter(vec![30, 40, 50]).boxed());

        futures::executor::block_on(async move {
            let expected = [(1, 10), (2, 20), (3, 30), (3, 40), (3, 50)];
            for pair in expected {
                assert_eq!(map.next().await, Some(pair));
            }
            assert_eq!(map.next().await, None);

            // After reporting exhaustion once, the map must park instead of yielding `None` again.
            let still_pending =
                futures::future::poll_fn(|cx| Poll::Ready(map.poll_next_unpin(cx).is_pending()))
                    .await;
            assert!(still_pending);
        })
    }

    #[test]
    fn get_from_map() {
        let mut map = StreamMap::new();
        map.insert(1, Idle::new(10));
        map.insert(2, Idle::new(20));

        let first = map.get(&1).expect("valid entry").get();
        let second = map.get(&2).expect("valid entry").get();
        assert_eq!((first, second), (&10, &20));

        map.get_mut(&1).expect("valid entry").set(100);
        let updated = *map.get(&1).expect("valid entry").get();
        assert_eq!(updated, 100);
    }
}
328        }
329    }
330}