// par_stream/pull.rs

1use crate::{common::*, config::BufSize, rt, utils};
2use flume::r#async::RecvStream;
3
4/// The builder forwards each stream item according to its key to a destination receiver.
5pub struct PullBuilder<St, K, F, Q = K>
6where
7    St: ?Sized + Stream,
8{
9    buf_size: Option<usize>,
10    key_fn: F,
11    senders: HashMap<K, flume::Sender<St::Item>>,
12    _phantom: PhantomData<Q>,
13    stream: St,
14}
15
16impl<St, K, Q, F> PullBuilder<St, K, F, Q>
17where
18    St: 'static + Send + Stream,
19    St::Item: 'static + Send,
20    F: 'static + Send + FnMut(&St::Item) -> Q,
21    K: 'static + Send + Hash + Eq + Borrow<Q>,
22    Q: Send + Hash + Eq,
23{
24    /// Creates the builder.
25    ///
26    /// The `buf_size` sets the channel size for each registered receiver.
27    /// The `key_fn` is used to compute the key for each input item.
28    pub fn new<B>(stream: St, buf_size: B, key_fn: F) -> Self
29    where
30        B: Into<BufSize>,
31    {
32        let buf_size = buf_size.into().get();
33
34        Self {
35            buf_size,
36            key_fn,
37            senders: HashMap::new(),
38            _phantom: PhantomData,
39            stream,
40        }
41    }
42
43    /// Creates a receiver binding to the `key`.
44    ///
45    /// If the `key` is already registered, it returns `None`.
46    pub fn register(&mut self, key: K) -> Option<RecvStream<'static, St::Item>> {
47        use std::collections::hash_map::Entry as E;
48
49        if let E::Vacant(entry) = self.senders.entry(key) {
50            let (tx, rx) = utils::channel(self.buf_size);
51            entry.insert(tx);
52            Some(rx.into_stream())
53        } else {
54            None
55        }
56    }
57
58    /// Finish the builder and start forwarding items to receivers.
59    ///
60    /// It returns a special leaking receiver that accepts items which
61    /// key is not registered or the destination receiver is closed.
62    pub fn build(self) -> RecvStream<'static, St::Item> {
63        let Self {
64            mut key_fn,
65            senders,
66            stream,
67            buf_size,
68            ..
69        } = self;
70        let (leak_tx, leak_rx) = utils::channel(buf_size);
71
72        rt::spawn(async move {
73            let mut stream = stream.boxed();
74
75            while let Some(item) = stream.next().await {
76                let query = key_fn(&item);
77                let tx = senders.get(&query);
78
79                if let Some(tx) = tx {
80                    if let Err(err) = tx.send_async(item).await {
81                        let _ = leak_tx.send_async(err.into_inner()).await;
82                    }
83                } else {
84                    let _ = leak_tx.send_async(item).await;
85                }
86            }
87        });
88
89        leak_rx.into_stream()
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{par_stream::ParStreamExt as _, utils::async_test};

    async_test! {
        async fn pull_routing_test() {
            let source = stream::iter([("A", 1), ("B", 2), ("C", 3), ("D", 4)]);
            let mut builder = source.pull_routing(None, |&(key, _)| key);

            // "D" is intentionally left unregistered so it ends up in the leak stream.
            let to_a = builder.register("A").unwrap();
            let to_b = builder.register("B").unwrap();
            let to_c = builder.register("C").unwrap();
            let leaked = builder.build();

            // Drain all four receivers concurrently rather than one after another.
            let outputs: Vec<Vec<_>> = future::join_all(vec![
                to_a.collect(),
                to_b.collect(),
                to_c.collect(),
                leaked.collect(),
            ])
            .await;

            let expected = vec![
                vec![("A", 1)],
                vec![("B", 2)],
                vec![("C", 3)],
                vec![("D", 4)],
            ];
            assert_eq!(outputs, expected);
        }
    }
}