1use std::collections::HashMap;
8use std::hash::Hash;
9
10use futures::stream::StreamExt;
11use tokio::sync::mpsc;
12
13use crate::source::Source;
14
15pub fn group_by<T, K, F>(src: Source<T>, max_substreams: usize, mut key_fn: F) -> Source<(K, Source<T>)>
21where
22 T: Send + 'static,
23 K: Eq + Hash + Clone + Send + 'static,
24 F: FnMut(&T) -> K + Send + 'static,
25{
26 assert!(max_substreams >= 1, "max_substreams must be >= 1");
27 let (outer_tx, outer_rx) = mpsc::unbounded_channel::<(K, Source<T>)>();
28 let mut inner = src.into_boxed();
29 tokio::spawn(async move {
30 let mut substreams: HashMap<K, mpsc::UnboundedSender<T>> = HashMap::new();
31 while let Some(item) = inner.next().await {
32 let key = key_fn(&item);
33 if let Some(tx) = substreams.get(&key) {
34 let _ = tx.send(item);
35 continue;
36 }
37 if substreams.len() >= max_substreams {
38 continue;
40 }
41 let (sub_tx, sub_rx) = mpsc::unbounded_channel::<T>();
42 let _ = sub_tx.send(item);
43 substreams.insert(key.clone(), sub_tx);
44 if outer_tx.send((key, Source::from_receiver(sub_rx))).is_err() {
45 return;
47 }
48 }
49 });
52 Source::from_receiver(outer_rx)
53}
54
55pub fn split_when<T, F>(src: Source<T>, mut pred: F) -> Source<Source<T>>
60where
61 T: Send + 'static,
62 F: FnMut(&T) -> bool + Send + 'static,
63{
64 let (outer_tx, outer_rx) = mpsc::unbounded_channel::<Source<T>>();
65 let mut inner = src.into_boxed();
66 tokio::spawn(async move {
67 let mut current_tx: Option<mpsc::UnboundedSender<T>> = None;
68 while let Some(item) = inner.next().await {
69 let split = pred(&item);
70 if split || current_tx.is_none() {
71 let (sub_tx, sub_rx) = mpsc::unbounded_channel::<T>();
72 if outer_tx.send(Source::from_receiver(sub_rx)).is_err() {
73 return;
74 }
75 current_tx = Some(sub_tx);
76 }
77 if let Some(tx) = ¤t_tx {
78 let _ = tx.send(item);
79 }
80 }
81 });
82 Source::from_receiver(outer_rx)
83}
84
85pub fn split_after<T, F>(src: Source<T>, mut pred: F) -> Source<Source<T>>
90where
91 T: Send + 'static,
92 F: FnMut(&T) -> bool + Send + 'static,
93{
94 let (outer_tx, outer_rx) = mpsc::unbounded_channel::<Source<T>>();
95 let mut inner = src.into_boxed();
96 tokio::spawn(async move {
97 let mut current_tx: Option<mpsc::UnboundedSender<T>> = None;
98 while let Some(item) = inner.next().await {
99 if current_tx.is_none() {
102 let (sub_tx, sub_rx) = mpsc::unbounded_channel::<T>();
103 if outer_tx.send(Source::from_receiver(sub_rx)).is_err() {
104 return;
105 }
106 current_tx = Some(sub_tx);
107 }
108 let split = pred(&item);
109 if let Some(tx) = ¤t_tx {
110 let _ = tx.send(item);
111 }
112 if split {
113 current_tx = None;
116 }
117 }
118 });
119 Source::from_receiver(outer_rx)
120}
121
122pub fn prefix_and_tail<T>(src: Source<T>, n: usize) -> Source<(Vec<T>, Source<T>)>
129where
130 T: Send + 'static,
131{
132 let (outer_tx, outer_rx) = mpsc::unbounded_channel::<(Vec<T>, Source<T>)>();
133 let mut inner = src.into_boxed();
134 tokio::spawn(async move {
135 let mut prefix = Vec::with_capacity(n);
136 for _ in 0..n {
137 match inner.next().await {
138 Some(it) => prefix.push(it),
139 None => break,
140 }
141 }
142 let (tail_tx, tail_rx) = mpsc::unbounded_channel::<T>();
143 if outer_tx.send((prefix, Source::from_receiver(tail_rx))).is_err() {
144 return;
145 }
146 while let Some(it) = inner.next().await {
147 if tail_tx.send(it).is_err() {
148 break;
149 }
150 }
151 });
152 Source::from_receiver(outer_rx)
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::collections::HashMap;

    /// Collects each substream in turn, keeping the outer order.
    async fn drain_all(subs: Vec<Source<i32>>) -> Vec<Vec<i32>> {
        let mut drained = Vec::with_capacity(subs.len());
        for sub in subs {
            drained.push(Sink::collect(sub).await);
        }
        drained
    }

    #[tokio::test]
    async fn group_by_partitions_into_substreams_by_key() {
        // Parity key: evens land under 0, odds under 1.
        let outer = group_by(Source::from_iter(vec![1, 2, 3, 4, 5, 6]), 2, |x: &i32| *x % 2);
        let mut by_key: HashMap<i32, Vec<i32>> = HashMap::new();
        for (key, sub) in Sink::collect(outer).await {
            let items = Sink::collect(sub).await;
            by_key.insert(key, items);
        }
        assert_eq!(by_key.get(&0), Some(&vec![2, 4, 6]));
        assert_eq!(by_key.get(&1), Some(&vec![1, 3, 5]));
    }

    #[tokio::test]
    async fn group_by_drops_keys_past_cap() {
        // With a cap of one, only the first key seen (1 % 3 == 1) gets a substream.
        let outer = group_by(Source::from_iter(vec![1, 2, 3, 4, 5, 6]), 1, |x: &i32| *x % 3);
        let pairs = Sink::collect(outer).await;
        assert_eq!(pairs.len(), 1);
        let mut remaining = pairs.into_iter();
        let (key, sub) = remaining.next().expect("one substream was emitted");
        assert_eq!(key, 1);
        let items = Sink::collect(sub).await;
        assert_eq!(items, vec![1, 4]);
    }

    #[tokio::test]
    async fn split_when_starts_new_substream_on_predicate() {
        let input = Source::from_iter(vec![1, 2, 10, 3, 4, 20, 5]);
        let subs = Sink::collect(split_when(input, |x: &i32| *x >= 10)).await;
        let chunks = drain_all(subs).await;
        assert_eq!(chunks, vec![vec![1, 2], vec![10, 3, 4], vec![20, 5]]);
    }

    #[tokio::test]
    async fn split_after_keeps_pivot_in_previous_chunk() {
        let input = Source::from_iter(vec![1, 2, 10, 3, 4, 20, 5]);
        let subs = Sink::collect(split_after(input, |x: &i32| *x >= 10)).await;
        let chunks = drain_all(subs).await;
        assert_eq!(chunks, vec![vec![1, 2, 10], vec![3, 4, 20], vec![5]]);
    }

    #[tokio::test]
    async fn prefix_and_tail_returns_first_n_then_rest() {
        let outer = prefix_and_tail(Source::from_iter(vec![1, 2, 3, 4, 5]), 2);
        let pairs = Sink::collect(outer).await;
        assert_eq!(pairs.len(), 1);
        let (prefix, tail) = pairs.into_iter().last().expect("one pair was emitted");
        assert_eq!(prefix, vec![1, 2]);
        let rest = Sink::collect(tail).await;
        assert_eq!(rest, vec![3, 4, 5]);
    }

    #[tokio::test]
    async fn prefix_and_tail_yields_short_prefix_when_source_exhausts() {
        // Asking for more than the source holds gives a short prefix and an empty tail.
        let outer = prefix_and_tail(Source::from_iter(vec![1, 2]), 5);
        let pairs = Sink::collect(outer).await;
        let (prefix, tail) = pairs.into_iter().last().expect("one pair was emitted");
        assert_eq!(prefix, vec![1, 2]);
        let rest = Sink::collect(tail).await;
        assert!(rest.is_empty());
    }
}