Skip to main content

atomr_streams/
junction.rs

//! Fan-in and fan-out junctions.
//!
//! This port exposes the common linear-composition junctions without the
//! upstream graph-DSL plumbing: `merge`, `merge_all`, `merge_sorted`,
//! `merge_prioritized`, `concat`, `zip`, `zip_with`, `zip_with_index`, and
//! `broadcast` (fan-out into two independent `Source<T>`s).
6
7use futures::stream::{select_all, StreamExt};
8
9use crate::source::Source;
10
11/// (interleaving, order not guaranteed).
12pub fn merge<T: Send + 'static>(a: Source<T>, b: Source<T>) -> Source<T> {
13    Source { inner: futures::stream::select(a.into_boxed(), b.into_boxed()).boxed() }
14}
15
16/// with arbitrary fan-in.
17pub fn merge_all<T: Send + 'static, I: IntoIterator<Item = Source<T>>>(sources: I) -> Source<T> {
18    let boxed = sources.into_iter().map(|s| s.into_boxed()).collect::<Vec<_>>();
19    Source { inner: select_all(boxed).boxed() }
20}
21
22/// Drain first source fully, then second.
23pub fn concat<T: Send + 'static>(a: Source<T>, b: Source<T>) -> Source<T> {
24    a.concat(b)
25}
26
27/// Pair corresponding elements.
28pub fn zip<A, B>(a: Source<A>, b: Source<B>) -> Source<(A, B)>
29where
30    A: Send + 'static,
31    B: Send + 'static,
32{
33    Source { inner: a.into_boxed().zip(b.into_boxed()).boxed() }
34}
35
36/// Pair corresponding elements and apply `f`.
37pub fn zip_with<A, B, C, F>(a: Source<A>, b: Source<B>, mut f: F) -> Source<C>
38where
39    A: Send + 'static,
40    B: Send + 'static,
41    C: Send + 'static,
42    F: FnMut(A, B) -> C + Send + 'static,
43{
44    Source { inner: a.into_boxed().zip(b.into_boxed()).map(move |(x, y)| f(x, y)).boxed() }
45}
46
47pub fn zip_with_index<T: Send + 'static>(source: Source<T>) -> Source<(u64, T)> {
48    Source { inner: source.into_boxed().enumerate().map(|(i, v)| (i as u64, v)).boxed() }
49}
50
51/// Merge two **already-sorted** sources
52/// preserving total order. Both inputs must be ascending; output is
53/// ascending. Buffers one element per side via tokio mpsc.
54pub fn merge_sorted<T: Ord + Send + 'static>(a: Source<T>, b: Source<T>) -> Source<T> {
55    let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<T>();
56    let mut sa = a.into_boxed();
57    let mut sb = b.into_boxed();
58    tokio::spawn(async move {
59        let mut head_a = sa.next().await;
60        let mut head_b = sb.next().await;
61        loop {
62            match (head_a.take(), head_b.take()) {
63                (None, None) => return,
64                (Some(x), None) => {
65                    if tx.send(x).is_err() {
66                        return;
67                    }
68                    while let Some(rest) = sa.next().await {
69                        if tx.send(rest).is_err() {
70                            return;
71                        }
72                    }
73                    return;
74                }
75                (None, Some(y)) => {
76                    if tx.send(y).is_err() {
77                        return;
78                    }
79                    while let Some(rest) = sb.next().await {
80                        if tx.send(rest).is_err() {
81                            return;
82                        }
83                    }
84                    return;
85                }
86                (Some(x), Some(y)) => {
87                    if x <= y {
88                        if tx.send(x).is_err() {
89                            return;
90                        }
91                        head_b = Some(y);
92                        head_a = sa.next().await;
93                    } else {
94                        if tx.send(y).is_err() {
95                            return;
96                        }
97                        head_a = Some(x);
98                        head_b = sb.next().await;
99                    }
100                }
101            }
102        }
103    });
104    Source::from_receiver(rx)
105}
106
107/// Every input contributes elements in
108/// proportion to its weight when both have items pending, falling
109/// through to whichever side has work otherwise. Weights ≥ 1.
110pub fn merge_prioritized<T: Send + 'static>(
111    a: Source<T>,
112    weight_a: u32,
113    b: Source<T>,
114    weight_b: u32,
115) -> Source<T> {
116    assert!(weight_a >= 1 && weight_b >= 1, "merge_prioritized weights must be >= 1");
117    let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<T>();
118    let mut sa = a.into_boxed();
119    let mut sb = b.into_boxed();
120    tokio::spawn(async move {
121        let mut budget_a = weight_a;
122        let mut budget_b = weight_b;
123        loop {
124            tokio::select! {
125                biased;
126                ax = sa.next(), if budget_a > 0 => match ax {
127                    Some(v) => {
128                        if tx.send(v).is_err() { return; }
129                        budget_a -= 1;
130                        if budget_a == 0 && budget_b == 0 {
131                            budget_a = weight_a;
132                            budget_b = weight_b;
133                        }
134                    }
135                    None => budget_a = 0,
136                },
137                bx = sb.next(), if budget_b > 0 => match bx {
138                    Some(v) => {
139                        if tx.send(v).is_err() { return; }
140                        budget_b -= 1;
141                        if budget_a == 0 && budget_b == 0 {
142                            budget_a = weight_a;
143                            budget_b = weight_b;
144                        }
145                    }
146                    None => budget_b = 0,
147                },
148                else => return,
149            }
150        }
151    });
152    Source::from_receiver(rx)
153}
154
155/// Cheap fan-out into two independent sources
156/// using cloned items and a bounded channel per downstream.
157pub fn broadcast<T>(source: Source<T>) -> (Source<T>, Source<T>)
158where
159    T: Clone + Send + 'static,
160{
161    let (tx_a, rx_a) = tokio::sync::mpsc::unbounded_channel::<T>();
162    let (tx_b, rx_b) = tokio::sync::mpsc::unbounded_channel::<T>();
163    let mut inner = source.into_boxed();
164    tokio::spawn(async move {
165        while let Some(item) = inner.next().await {
166            let _ = tx_a.send(item.clone());
167            let _ = tx_b.send(item);
168        }
169    });
170    (Source::from_receiver(rx_a), Source::from_receiver(rx_b))
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    #[tokio::test]
    async fn merge_interleaves_two_sources() {
        let a = Source::from_iter(vec![1, 2, 3]);
        let b = Source::from_iter(vec![10, 20, 30]);
        let mut out = Sink::collect(merge(a, b)).await;
        out.sort();
        assert_eq!(out, vec![1, 2, 3, 10, 20, 30]);
    }

    #[tokio::test]
    async fn merge_all_collects_every_source() {
        let sources = vec![
            Source::from_iter(vec![1, 2]),
            Source::from_iter(vec![3]),
            Source::from_iter(vec![4, 5]),
        ];
        let mut out = Sink::collect(merge_all(sources)).await;
        out.sort();
        assert_eq!(out, vec![1, 2, 3, 4, 5]);
    }

    #[tokio::test]
    async fn concat_drains_first_then_second() {
        let out =
            Sink::collect(concat(Source::from_iter(vec![1, 2]), Source::from_iter(vec![3, 4])))
                .await;
        assert_eq!(out, vec![1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn zip_pairs_sources() {
        let out =
            Sink::collect(zip(Source::from_iter(vec!["a", "b", "c"]), Source::from_iter(vec![1, 2, 3])))
                .await;
        assert_eq!(out, vec![("a", 1), ("b", 2), ("c", 3)]);
    }

    #[tokio::test]
    async fn zip_with_applies_combiner() {
        let out = Sink::collect(zip_with(
            Source::from_iter(vec![1, 2, 3]),
            Source::from_iter(vec![10, 20, 30]),
            |x, y| x + y,
        ))
        .await;
        assert_eq!(out, vec![11, 22, 33]);
    }

    #[tokio::test]
    async fn zip_with_index_numbers_elements() {
        let out = Sink::collect(zip_with_index(Source::from_iter(vec!["x", "y"]))).await;
        assert_eq!(out, vec![(0, "x"), (1, "y")]);
    }

    #[tokio::test]
    async fn merge_sorted_preserves_order() {
        let a = Source::from_iter(vec![1, 4, 4, 9]);
        let b = Source::from_iter(vec![2, 3, 4, 10, 11]);
        let out = Sink::collect(merge_sorted(a, b)).await;
        assert_eq!(out, vec![1, 2, 3, 4, 4, 4, 9, 10, 11]);
    }

    #[tokio::test]
    async fn merge_sorted_passes_through_when_one_side_is_empty() {
        let a = Source::from_iter(Vec::<i32>::new());
        let b = Source::from_iter(vec![1, 2]);
        let out = Sink::collect(merge_sorted(a, b)).await;
        assert_eq!(out, vec![1, 2]);
    }

    #[tokio::test]
    async fn merge_prioritized_emits_every_element() {
        let a = Source::from_iter(vec![1, 2, 3]);
        let b = Source::from_iter(vec![10, 20]);
        let mut out = Sink::collect(merge_prioritized(a, 2, b, 1)).await;
        out.sort();
        assert_eq!(out, vec![1, 2, 3, 10, 20]);
    }

    #[test]
    #[should_panic(expected = "merge_prioritized weights must be >= 1")]
    fn merge_prioritized_rejects_zero_weight() {
        let _ = merge_prioritized(Source::from_iter(vec![1]), 0, Source::from_iter(vec![2]), 1);
    }

    #[tokio::test]
    async fn broadcast_duplicates_elements() {
        let (a, b) = broadcast(Source::from_iter(vec![1, 2, 3]));
        let (ra, rb) = tokio::join!(Sink::collect(a), Sink::collect(b));
        assert_eq!(ra, vec![1, 2, 3]);
        assert_eq!(rb, vec![1, 2, 3]);
    }
}