rakka_streams/
substream.rs1use std::collections::HashMap;
11use std::hash::Hash;
12
13use futures::stream::StreamExt;
14use tokio::sync::mpsc;
15
16use crate::source::Source;
17
18pub fn group_by<T, K, F>(src: Source<T>, max_substreams: usize, mut key_fn: F) -> Source<(K, Source<T>)>
25where
26 T: Send + 'static,
27 K: Eq + Hash + Clone + Send + 'static,
28 F: FnMut(&T) -> K + Send + 'static,
29{
30 assert!(max_substreams >= 1, "max_substreams must be >= 1");
31 let (outer_tx, outer_rx) = mpsc::unbounded_channel::<(K, Source<T>)>();
32 let mut inner = src.into_boxed();
33 tokio::spawn(async move {
34 let mut substreams: HashMap<K, mpsc::UnboundedSender<T>> = HashMap::new();
35 while let Some(item) = inner.next().await {
36 let key = key_fn(&item);
37 if let Some(tx) = substreams.get(&key) {
38 let _ = tx.send(item);
39 continue;
40 }
41 if substreams.len() >= max_substreams {
42 continue;
44 }
45 let (sub_tx, sub_rx) = mpsc::unbounded_channel::<T>();
46 let _ = sub_tx.send(item);
47 substreams.insert(key.clone(), sub_tx);
48 if outer_tx.send((key, Source::from_receiver(sub_rx))).is_err() {
49 return;
51 }
52 }
53 });
56 Source::from_receiver(outer_rx)
57}
58
59pub fn split_when<T, F>(src: Source<T>, mut pred: F) -> Source<Source<T>>
65where
66 T: Send + 'static,
67 F: FnMut(&T) -> bool + Send + 'static,
68{
69 let (outer_tx, outer_rx) = mpsc::unbounded_channel::<Source<T>>();
70 let mut inner = src.into_boxed();
71 tokio::spawn(async move {
72 let mut current_tx: Option<mpsc::UnboundedSender<T>> = None;
73 while let Some(item) = inner.next().await {
74 let split = pred(&item);
75 if split || current_tx.is_none() {
76 let (sub_tx, sub_rx) = mpsc::unbounded_channel::<T>();
77 if outer_tx.send(Source::from_receiver(sub_rx)).is_err() {
78 return;
79 }
80 current_tx = Some(sub_tx);
81 }
82 if let Some(tx) = ¤t_tx {
83 let _ = tx.send(item);
84 }
85 }
86 });
87 Source::from_receiver(outer_rx)
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::collections::HashMap;

    /// Parity grouping with room for both keys: each substream gets its own elements, in order.
    #[tokio::test]
    async fn group_by_partitions_into_substreams_by_key() {
        let grouped = group_by(Source::from_iter(vec![1, 2, 3, 4, 5, 6]), 2, |n: &i32| *n % 2);
        let mut collected: HashMap<i32, Vec<i32>> = HashMap::new();
        for (parity, substream) in Sink::collect(grouped).await {
            let values = Sink::collect(substream).await;
            collected.insert(parity, values);
        }
        assert_eq!(collected.get(&0), Some(&vec![2, 4, 6]));
        assert_eq!(collected.get(&1), Some(&vec![1, 3, 5]));
    }

    /// Keys are `n % 3`. With a cap of one, only the first key seen (1) gets a substream;
    /// elements of every other key are dropped.
    #[tokio::test]
    async fn group_by_drops_keys_past_cap() {
        let grouped = group_by(Source::from_iter(vec![1, 2, 3, 4, 5, 6]), 1, |n: &i32| *n % 3);
        let groups = Sink::collect(grouped).await;
        assert_eq!(groups.len(), 1);
        let (first_key, first_sub) = groups.into_iter().next().unwrap();
        assert_eq!(first_key, 1);
        let first_values = Sink::collect(first_sub).await;
        assert_eq!(first_values, vec![1, 4]);
    }

    /// Each element >= 10 opens a new substream and becomes its first element.
    #[tokio::test]
    async fn split_when_starts_new_substream_on_predicate() {
        let pieces = split_when(Source::from_iter(vec![1, 2, 10, 3, 4, 20, 5]), |n: &i32| *n >= 10);
        let mut collected = Vec::new();
        for piece in Sink::collect(pieces).await {
            collected.push(Sink::collect(piece).await);
        }
        assert_eq!(collected, vec![vec![1, 2], vec![10, 3, 4], vec![20, 5]]);
    }
}