1use futures::stream::{select_all, StreamExt};
8
9use crate::source::Source;
10
11pub fn merge<T: Send + 'static>(a: Source<T>, b: Source<T>) -> Source<T> {
13 Source { inner: futures::stream::select(a.into_boxed(), b.into_boxed()).boxed() }
14}
15
16pub fn merge_all<T: Send + 'static, I: IntoIterator<Item = Source<T>>>(sources: I) -> Source<T> {
18 let boxed = sources.into_iter().map(|s| s.into_boxed()).collect::<Vec<_>>();
19 Source { inner: select_all(boxed).boxed() }
20}
21
/// Free-function form of concatenation: delegates to the `concat` method on
/// `a`, passing `b` as its argument.
///
/// NOTE(review): presumably the result yields every element of `a` followed by
/// every element of `b` — confirm against the `concat` method on `Source`.
pub fn concat<T: Send + 'static>(a: Source<T>, b: Source<T>) -> Source<T> {
    a.concat(b)
}
26
27pub fn zip<A, B>(a: Source<A>, b: Source<B>) -> Source<(A, B)>
29where
30 A: Send + 'static,
31 B: Send + 'static,
32{
33 Source { inner: a.into_boxed().zip(b.into_boxed()).boxed() }
34}
35
36pub fn zip_with<A, B, C, F>(a: Source<A>, b: Source<B>, mut f: F) -> Source<C>
38where
39 A: Send + 'static,
40 B: Send + 'static,
41 C: Send + 'static,
42 F: FnMut(A, B) -> C + Send + 'static,
43{
44 Source { inner: a.into_boxed().zip(b.into_boxed()).map(move |(x, y)| f(x, y)).boxed() }
45}
46
47pub fn zip_with_index<T: Send + 'static>(source: Source<T>) -> Source<(u64, T)> {
48 Source { inner: source.into_boxed().enumerate().map(|(i, v)| (i as u64, v)).boxed() }
49}
50
51pub fn merge_sorted<T: Ord + Send + 'static>(a: Source<T>, b: Source<T>) -> Source<T> {
55 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<T>();
56 let mut sa = a.into_boxed();
57 let mut sb = b.into_boxed();
58 tokio::spawn(async move {
59 let mut head_a = sa.next().await;
60 let mut head_b = sb.next().await;
61 loop {
62 match (head_a.take(), head_b.take()) {
63 (None, None) => return,
64 (Some(x), None) => {
65 if tx.send(x).is_err() {
66 return;
67 }
68 while let Some(rest) = sa.next().await {
69 if tx.send(rest).is_err() {
70 return;
71 }
72 }
73 return;
74 }
75 (None, Some(y)) => {
76 if tx.send(y).is_err() {
77 return;
78 }
79 while let Some(rest) = sb.next().await {
80 if tx.send(rest).is_err() {
81 return;
82 }
83 }
84 return;
85 }
86 (Some(x), Some(y)) => {
87 if x <= y {
88 if tx.send(x).is_err() {
89 return;
90 }
91 head_b = Some(y);
92 head_a = sa.next().await;
93 } else {
94 if tx.send(y).is_err() {
95 return;
96 }
97 head_a = Some(x);
98 head_b = sb.next().await;
99 }
100 }
101 }
102 }
103 });
104 Source::from_receiver(rx)
105}
106
107pub fn merge_prioritized<T: Send + 'static>(
111 a: Source<T>,
112 weight_a: u32,
113 b: Source<T>,
114 weight_b: u32,
115) -> Source<T> {
116 assert!(weight_a >= 1 && weight_b >= 1, "merge_prioritized weights must be >= 1");
117 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<T>();
118 let mut sa = a.into_boxed();
119 let mut sb = b.into_boxed();
120 tokio::spawn(async move {
121 let mut budget_a = weight_a;
122 let mut budget_b = weight_b;
123 loop {
124 tokio::select! {
125 biased;
126 ax = sa.next(), if budget_a > 0 => match ax {
127 Some(v) => {
128 if tx.send(v).is_err() { return; }
129 budget_a -= 1;
130 if budget_a == 0 && budget_b == 0 {
131 budget_a = weight_a;
132 budget_b = weight_b;
133 }
134 }
135 None => budget_a = 0,
136 },
137 bx = sb.next(), if budget_b > 0 => match bx {
138 Some(v) => {
139 if tx.send(v).is_err() { return; }
140 budget_b -= 1;
141 if budget_a == 0 && budget_b == 0 {
142 budget_a = weight_a;
143 budget_b = weight_b;
144 }
145 }
146 None => budget_b = 0,
147 },
148 else => return,
149 }
150 }
151 });
152 Source::from_receiver(rx)
153}
154
155pub fn broadcast<T>(source: Source<T>) -> (Source<T>, Source<T>)
158where
159 T: Clone + Send + 'static,
160{
161 let (tx_a, rx_a) = tokio::sync::mpsc::unbounded_channel::<T>();
162 let (tx_b, rx_b) = tokio::sync::mpsc::unbounded_channel::<T>();
163 let mut inner = source.into_boxed();
164 tokio::spawn(async move {
165 while let Some(item) = inner.next().await {
166 let _ = tx_a.send(item.clone());
167 let _ = tx_b.send(item);
168 }
169 });
170 (Source::from_receiver(rx_a), Source::from_receiver(rx_b))
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    // `merge` gives no ordering guarantee between inputs, so the collected
    // output is sorted before comparing.
    #[tokio::test]
    async fn merge_interleaves_two_sources() {
        let small = Source::from_iter(vec![1, 2, 3]);
        let large = Source::from_iter(vec![10, 20, 30]);
        let merged = merge(small, large);
        let mut collected = Sink::collect(merged).await;
        collected.sort();
        assert_eq!(collected, vec![1, 2, 3, 10, 20, 30]);
    }

    // `zip` pairs elements by position.
    #[tokio::test]
    async fn zip_pairs_sources() {
        let letters = Source::from_iter(vec!["a", "b", "c"]);
        let numbers = Source::from_iter(vec![1, 2, 3]);
        let pairs = Sink::collect(zip(letters, numbers)).await;
        assert_eq!(pairs, vec![("a", 1), ("b", 2), ("c", 3)]);
    }

    // Indices start at zero and follow element order.
    #[tokio::test]
    async fn zip_with_index_numbers_elements() {
        let indexed = zip_with_index(Source::from_iter(vec!["x", "y"]));
        let collected = Sink::collect(indexed).await;
        assert_eq!(collected, vec![(0, "x"), (1, "y")]);
    }

    // Both outputs of `broadcast` see the full input, in order.
    #[tokio::test]
    async fn broadcast_duplicates_elements() {
        let (first, second) = broadcast(Source::from_iter(vec![1, 2, 3]));
        let (from_first, from_second) = tokio::join!(Sink::collect(first), Sink::collect(second));
        assert_eq!(from_first, vec![1, 2, 3]);
        assert_eq!(from_second, vec![1, 2, 3]);
    }
}