1use crate::{common::*, config::BufSize, rt, utils};
2use flume::r#async::RecvStream;
3
4pub struct PullBuilder<St, K, F, Q = K>
6where
7 St: ?Sized + Stream,
8{
9 buf_size: Option<usize>,
10 key_fn: F,
11 senders: HashMap<K, flume::Sender<St::Item>>,
12 _phantom: PhantomData<Q>,
13 stream: St,
14}
15
16impl<St, K, Q, F> PullBuilder<St, K, F, Q>
17where
18 St: 'static + Send + Stream,
19 St::Item: 'static + Send,
20 F: 'static + Send + FnMut(&St::Item) -> Q,
21 K: 'static + Send + Hash + Eq + Borrow<Q>,
22 Q: Send + Hash + Eq,
23{
24 pub fn new<B>(stream: St, buf_size: B, key_fn: F) -> Self
29 where
30 B: Into<BufSize>,
31 {
32 let buf_size = buf_size.into().get();
33
34 Self {
35 buf_size,
36 key_fn,
37 senders: HashMap::new(),
38 _phantom: PhantomData,
39 stream,
40 }
41 }
42
43 pub fn register(&mut self, key: K) -> Option<RecvStream<'static, St::Item>> {
47 use std::collections::hash_map::Entry as E;
48
49 if let E::Vacant(entry) = self.senders.entry(key) {
50 let (tx, rx) = utils::channel(self.buf_size);
51 entry.insert(tx);
52 Some(rx.into_stream())
53 } else {
54 None
55 }
56 }
57
58 pub fn build(self) -> RecvStream<'static, St::Item> {
63 let Self {
64 mut key_fn,
65 senders,
66 stream,
67 buf_size,
68 ..
69 } = self;
70 let (leak_tx, leak_rx) = utils::channel(buf_size);
71
72 rt::spawn(async move {
73 let mut stream = stream.boxed();
74
75 while let Some(item) = stream.next().await {
76 let query = key_fn(&item);
77 let tx = senders.get(&query);
78
79 if let Some(tx) = tx {
80 if let Err(err) = tx.send_async(item).await {
81 let _ = leak_tx.send_async(err.into_inner()).await;
82 }
83 } else {
84 let _ = leak_tx.send_async(item).await;
85 }
86 }
87 });
88
89 leak_rx.into_stream()
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{par_stream::ParStreamExt as _, utils::async_test};

    async_test! {
        // Items keyed "A", "B" and "C" must arrive at their registered
        // receivers; the unregistered "D" item must arrive at the leak receiver.
        async fn pull_routing_test() {
            let items = [("A", 1), ("B", 2), ("C", 3), ("D", 4)];
            let mut builder = stream::iter(items).pull_routing(None, |&(key, _)| key);

            let rx_a = builder.register("A").unwrap();
            let rx_b = builder.register("B").unwrap();
            let rx_c = builder.register("C").unwrap();
            let rx_leak = builder.build();

            // Drain all receivers concurrently; `join_all` keeps the input order.
            let receivers = vec![rx_a, rx_b, rx_c, rx_leak];
            let outputs: Vec<Vec<_>> =
                future::join_all(receivers.into_iter().map(|rx| rx.collect())).await;

            let expected = vec![
                vec![("A", 1)],
                vec![("B", 2)],
                vec![("C", 3)],
                vec![("D", 4)],
            ];
            assert_eq!(outputs, expected);
        }
    }
}