Skip to main content

atomr_streams/
substream.rs

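//! Substream operators on `Source<T>`.
//!
//! Operators: `group_by`, `split_when`, `split_after` and `prefix_and_tail`.
//! `group_by` returns a stream of `(key, Source<T>)` pairs, the split variants
//! return a stream of `Source<T>`, and `prefix_and_tail` returns a single
//! `(Vec<T>, Source<T>)` pair. Elements are buffered through unbounded tokio
//! mpsc channels and pumped by a task started with `tokio::spawn`, so these
//! operators must be called from within a tokio runtime.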
6
7use std::collections::HashMap;
8use std::hash::Hash;
9
10use futures::stream::StreamExt;
11use tokio::sync::mpsc;
12
13use crate::source::Source;
14
15/// `group_by(max_substreams, key_fn)` — fan one source into N
16/// per-key substreams. Each new key yields a `(key, Source<T>)`
17/// pair on the returned outer source. Once `max_substreams` keys
18/// are open, additional keys' elements are dropped.
19///
20pub fn group_by<T, K, F>(src: Source<T>, max_substreams: usize, mut key_fn: F) -> Source<(K, Source<T>)>
21where
22    T: Send + 'static,
23    K: Eq + Hash + Clone + Send + 'static,
24    F: FnMut(&T) -> K + Send + 'static,
25{
26    assert!(max_substreams >= 1, "max_substreams must be >= 1");
27    let (outer_tx, outer_rx) = mpsc::unbounded_channel::<(K, Source<T>)>();
28    let mut inner = src.into_boxed();
29    tokio::spawn(async move {
30        let mut substreams: HashMap<K, mpsc::UnboundedSender<T>> = HashMap::new();
31        while let Some(item) = inner.next().await {
32            let key = key_fn(&item);
33            if let Some(tx) = substreams.get(&key) {
34                let _ = tx.send(item);
35                continue;
36            }
37            if substreams.len() >= max_substreams {
38                // Spec-aligned: silently drop new keys past the cap.
39                continue;
40            }
41            let (sub_tx, sub_rx) = mpsc::unbounded_channel::<T>();
42            let _ = sub_tx.send(item);
43            substreams.insert(key.clone(), sub_tx);
44            if outer_tx.send((key, Source::from_receiver(sub_rx))).is_err() {
45                // Outer consumer dropped; abort.
46                return;
47            }
48        }
49        // Upstream complete — drop sub_tx senders so each substream
50        // sees clean termination. Done by HashMap drop.
51    });
52    Source::from_receiver(outer_rx)
53}
54
55/// `split_when(pred)` — split the source into a sequence of
56/// substreams; a new substream begins when `pred(item)` returns true,
57/// with the splitting element going to the **new** substream.
58///
59pub fn split_when<T, F>(src: Source<T>, mut pred: F) -> Source<Source<T>>
60where
61    T: Send + 'static,
62    F: FnMut(&T) -> bool + Send + 'static,
63{
64    let (outer_tx, outer_rx) = mpsc::unbounded_channel::<Source<T>>();
65    let mut inner = src.into_boxed();
66    tokio::spawn(async move {
67        let mut current_tx: Option<mpsc::UnboundedSender<T>> = None;
68        while let Some(item) = inner.next().await {
69            let split = pred(&item);
70            if split || current_tx.is_none() {
71                let (sub_tx, sub_rx) = mpsc::unbounded_channel::<T>();
72                if outer_tx.send(Source::from_receiver(sub_rx)).is_err() {
73                    return;
74                }
75                current_tx = Some(sub_tx);
76            }
77            if let Some(tx) = &current_tx {
78                let _ = tx.send(item);
79            }
80        }
81    });
82    Source::from_receiver(outer_rx)
83}
84
85/// `split_after(pred)` — like `split_when`, except the splitting
86/// element stays with the **previous** substream and the next element
87/// starts a new one.
88///
89pub fn split_after<T, F>(src: Source<T>, mut pred: F) -> Source<Source<T>>
90where
91    T: Send + 'static,
92    F: FnMut(&T) -> bool + Send + 'static,
93{
94    let (outer_tx, outer_rx) = mpsc::unbounded_channel::<Source<T>>();
95    let mut inner = src.into_boxed();
96    tokio::spawn(async move {
97        let mut current_tx: Option<mpsc::UnboundedSender<T>> = None;
98        while let Some(item) = inner.next().await {
99            // Open a new substream lazily on the first element or
100            // immediately after a previous split-end.
101            if current_tx.is_none() {
102                let (sub_tx, sub_rx) = mpsc::unbounded_channel::<T>();
103                if outer_tx.send(Source::from_receiver(sub_rx)).is_err() {
104                    return;
105                }
106                current_tx = Some(sub_tx);
107            }
108            let split = pred(&item);
109            if let Some(tx) = &current_tx {
110                let _ = tx.send(item);
111            }
112            if split {
113                // End the current substream; the next element starts
114                // a fresh one.
115                current_tx = None;
116            }
117        }
118    });
119    Source::from_receiver(outer_rx)
120}
121
122/// `prefix_and_tail(n)` — return the first `n` elements as a `Vec`
123/// alongside a `Source<T>` carrying the rest.
124///
125/// The single-shot result is
126/// delivered as the only element of the returned source so it composes
127/// uniformly with downstream operators.
128pub fn prefix_and_tail<T>(src: Source<T>, n: usize) -> Source<(Vec<T>, Source<T>)>
129where
130    T: Send + 'static,
131{
132    let (outer_tx, outer_rx) = mpsc::unbounded_channel::<(Vec<T>, Source<T>)>();
133    let mut inner = src.into_boxed();
134    tokio::spawn(async move {
135        let mut prefix = Vec::with_capacity(n);
136        for _ in 0..n {
137            match inner.next().await {
138                Some(it) => prefix.push(it),
139                None => break,
140            }
141        }
142        let (tail_tx, tail_rx) = mpsc::unbounded_channel::<T>();
143        if outer_tx.send((prefix, Source::from_receiver(tail_rx))).is_err() {
144            return;
145        }
146        while let Some(it) = inner.next().await {
147            if tail_tx.send(it).is_err() {
148                break;
149            }
150        }
151    });
152    Source::from_receiver(outer_rx)
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::collections::HashMap;

    /// Collects each substream in order, yielding one `Vec` per substream.
    async fn drain_each(subs: Vec<Source<i32>>) -> Vec<Vec<i32>> {
        let mut drained = Vec::with_capacity(subs.len());
        for sub in subs {
            drained.push(Sink::collect(sub).await);
        }
        drained
    }

    #[tokio::test]
    async fn group_by_partitions_into_substreams_by_key() {
        let src = Source::from_iter(vec![1, 2, 3, 4, 5, 6]);
        let outer = group_by(src, 2, |x: &i32| *x % 2);
        let mut by_key: HashMap<i32, Vec<i32>> = HashMap::new();
        for (key, sub) in Sink::collect(outer).await {
            by_key.insert(key, Sink::collect(sub).await);
        }
        assert_eq!(by_key.get(&0), Some(&vec![2, 4, 6]));
        assert_eq!(by_key.get(&1), Some(&vec![1, 3, 5]));
    }

    #[tokio::test]
    async fn group_by_drops_keys_past_cap() {
        let src = Source::from_iter(vec![1, 2, 3, 4, 5, 6]);
        // Cap at 1 — only the first key (=1) gets a substream.
        let outer = group_by(src, 1, |x: &i32| *x % 3);
        let pairs = Sink::collect(outer).await;
        assert_eq!(pairs.len(), 1);
        let (key, sub) = pairs.into_iter().next().unwrap();
        assert_eq!(key, 1);
        assert_eq!(Sink::collect(sub).await, vec![1, 4]);
    }

    #[tokio::test]
    async fn split_when_starts_new_substream_on_predicate() {
        let src = Source::from_iter(vec![1, 2, 10, 3, 4, 20, 5]);
        let outer = split_when(src, |x: &i32| *x >= 10);
        let chunks = drain_each(Sink::collect(outer).await).await;
        assert_eq!(chunks, vec![vec![1, 2], vec![10, 3, 4], vec![20, 5]]);
    }

    #[tokio::test]
    async fn split_after_keeps_pivot_in_previous_chunk() {
        let src = Source::from_iter(vec![1, 2, 10, 3, 4, 20, 5]);
        let outer = split_after(src, |x: &i32| *x >= 10);
        let chunks = drain_each(Sink::collect(outer).await).await;
        assert_eq!(chunks, vec![vec![1, 2, 10], vec![3, 4, 20], vec![5]]);
    }

    #[tokio::test]
    async fn prefix_and_tail_returns_first_n_then_rest() {
        let src = Source::from_iter(vec![1, 2, 3, 4, 5]);
        let mut pairs = Sink::collect(prefix_and_tail(src, 2)).await;
        assert_eq!(pairs.len(), 1);
        let (prefix, tail) = pairs.pop().unwrap();
        assert_eq!(prefix, vec![1, 2]);
        assert_eq!(Sink::collect(tail).await, vec![3, 4, 5]);
    }

    #[tokio::test]
    async fn prefix_and_tail_yields_short_prefix_when_source_exhausts() {
        let src = Source::from_iter(vec![1, 2]);
        let (prefix, tail) = Sink::collect(prefix_and_tail(src, 5)).await.pop().unwrap();
        assert_eq!(prefix, vec![1, 2]);
        assert!(Sink::collect(tail).await.is_empty());
    }
}