Skip to main content

rakka_streams/
substream.rs

1//! Substream operators on `Source<T>`.
2//!
3//! Phase 12.1 of `docs/full-port-plan.md`. Akka.NET / Akka Streams
4//! parity: `GroupBy`, `SplitWhen`, `SplitAfter`. We ship the
5//! pragmatic shape: each operator returns a stream of
6//! `(key, Source<T>)` (for `group_by`) or `Source<T>` (for split
7//! variants), buffered through tokio mpsc channels rather than the
8//! materializer-coordinated SubFlow algebra of the JVM port.
9
10use std::collections::HashMap;
11use std::hash::Hash;
12
13use futures::stream::StreamExt;
14use tokio::sync::mpsc;
15
16use crate::source::Source;
17
18/// `group_by(max_substreams, key_fn)` — fan one source into N
19/// per-key substreams. Each new key yields a `(key, Source<T>)`
20/// pair on the returned outer source. Once `max_substreams` keys
21/// are open, additional keys' elements are dropped.
22///
23/// Akka.NET: `Source.GroupBy(maxSubstreams, key)`.
24pub fn group_by<T, K, F>(src: Source<T>, max_substreams: usize, mut key_fn: F) -> Source<(K, Source<T>)>
25where
26    T: Send + 'static,
27    K: Eq + Hash + Clone + Send + 'static,
28    F: FnMut(&T) -> K + Send + 'static,
29{
30    assert!(max_substreams >= 1, "max_substreams must be >= 1");
31    let (outer_tx, outer_rx) = mpsc::unbounded_channel::<(K, Source<T>)>();
32    let mut inner = src.into_boxed();
33    tokio::spawn(async move {
34        let mut substreams: HashMap<K, mpsc::UnboundedSender<T>> = HashMap::new();
35        while let Some(item) = inner.next().await {
36            let key = key_fn(&item);
37            if let Some(tx) = substreams.get(&key) {
38                let _ = tx.send(item);
39                continue;
40            }
41            if substreams.len() >= max_substreams {
42                // Spec-aligned: silently drop new keys past the cap.
43                continue;
44            }
45            let (sub_tx, sub_rx) = mpsc::unbounded_channel::<T>();
46            let _ = sub_tx.send(item);
47            substreams.insert(key.clone(), sub_tx);
48            if outer_tx.send((key, Source::from_receiver(sub_rx))).is_err() {
49                // Outer consumer dropped; abort.
50                return;
51            }
52        }
53        // Upstream complete — drop sub_tx senders so each substream
54        // sees clean termination. Done by HashMap drop.
55    });
56    Source::from_receiver(outer_rx)
57}
58
59/// `split_when(pred)` — split the source into a sequence of
60/// substreams; a new substream begins when `pred(item)` returns true,
61/// with the splitting element going to the **new** substream.
62///
63/// Akka.NET: `Source.SplitWhen(pred)`.
64pub fn split_when<T, F>(src: Source<T>, mut pred: F) -> Source<Source<T>>
65where
66    T: Send + 'static,
67    F: FnMut(&T) -> bool + Send + 'static,
68{
69    let (outer_tx, outer_rx) = mpsc::unbounded_channel::<Source<T>>();
70    let mut inner = src.into_boxed();
71    tokio::spawn(async move {
72        let mut current_tx: Option<mpsc::UnboundedSender<T>> = None;
73        while let Some(item) = inner.next().await {
74            let split = pred(&item);
75            if split || current_tx.is_none() {
76                let (sub_tx, sub_rx) = mpsc::unbounded_channel::<T>();
77                if outer_tx.send(Source::from_receiver(sub_rx)).is_err() {
78                    return;
79                }
80                current_tx = Some(sub_tx);
81            }
82            if let Some(tx) = &current_tx {
83                let _ = tx.send(item);
84            }
85        }
86    });
87    Source::from_receiver(outer_rx)
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::collections::HashMap;

    /// Two keys (even / odd) with a cap of two: every element must land
    /// in the substream for its key, in upstream order.
    #[tokio::test]
    async fn group_by_partitions_into_substreams_by_key() {
        let numbers = Source::from_iter(vec![1, 2, 3, 4, 5, 6]);
        let grouped = group_by(numbers, 2, |n: &i32| *n % 2);

        let mut collected: HashMap<i32, Vec<i32>> = HashMap::new();
        for (key, substream) in Sink::collect(grouped).await {
            let elements = Sink::collect(substream).await;
            collected.insert(key, elements);
        }

        assert_eq!(collected.get(&0), Some(&vec![2, 4, 6]));
        assert_eq!(collected.get(&1), Some(&vec![1, 3, 5]));
    }

    /// Three distinct keys with a cap of one: only the first key seen
    /// (1 % 3 == 1) gets a substream; the other keys' elements vanish.
    #[tokio::test]
    async fn group_by_drops_keys_past_cap() {
        let numbers = Source::from_iter(vec![1, 2, 3, 4, 5, 6]);
        let grouped = group_by(numbers, 1, |n: &i32| *n % 3);

        let announced = Sink::collect(grouped).await;
        assert_eq!(announced.len(), 1);

        let (key, substream) = announced.into_iter().next().unwrap();
        assert_eq!(key, 1);

        let elements = Sink::collect(substream).await;
        assert_eq!(elements, vec![1, 4]);
    }

    /// Elements >= 10 act as split markers and must lead the substream
    /// they open.
    #[tokio::test]
    async fn split_when_starts_new_substream_on_predicate() {
        let numbers = Source::from_iter(vec![1, 2, 10, 3, 4, 20, 5]);
        let split = split_when(numbers, |n: &i32| *n >= 10);

        let mut segments = Vec::new();
        for substream in Sink::collect(split).await {
            segments.push(Sink::collect(substream).await);
        }

        assert_eq!(segments, vec![vec![1, 2], vec![10, 3, 4], vec![20, 5]]);
    }
}