Skip to main content

atomr_streams/
sink.rs

1//! Sink — consumes a `Source`, produces a materialized value. akka.net: `Dsl/Sink.cs`.
2//!
3//! Each factory here returns a future that drives the source to completion and
4//! produces the materialized value. These wrappers mirror the most common
5//! Akka.Streams sinks (`Fold`, `Aggregate`, `Sum`, `First`, `Last`, `Seq`,
6//! `ForEach`, `Ignore`) and add a lightweight `SinkQueue`.
7
8use std::future::Future;
9use std::sync::Arc;
10use std::time::Duration;
11
12use futures::stream::StreamExt;
13use parking_lot::Mutex;
14use tokio::sync::Notify;
15
16use crate::source::Source;
17
18pub struct Sink;
19
20impl Sink {
21    /// akka.net: `Fold` — drive the source and accumulate a single value.
22    pub async fn fold<T, Acc, F>(source: Source<T>, init: Acc, mut f: F) -> Acc
23    where
24        T: Send + 'static,
25        Acc: Send + 'static,
26        F: FnMut(Acc, T) -> Acc + Send + 'static,
27    {
28        source.into_boxed().fold(init, move |acc, x| futures::future::ready(f(acc, x))).await
29    }
30
31    /// akka.net: `AggregateAsync` — async fold.
32    pub async fn fold_async<T, Acc, F, Fut>(source: Source<T>, init: Acc, f: F) -> Acc
33    where
34        T: Send + 'static,
35        Acc: Send + 'static,
36        F: FnMut(Acc, T) -> Fut + Send + 'static,
37        Fut: Future<Output = Acc> + Send + 'static,
38    {
39        source.into_boxed().fold(init, f).await
40    }
41
42    /// akka.net: `Sink.Seq` — collect into a Vec.
43    pub async fn collect<T>(source: Source<T>) -> Vec<T>
44    where
45        T: Send + 'static,
46    {
47        source.into_boxed().collect().await
48    }
49
50    /// akka.net: `Sink.First`.
51    pub async fn first<T>(source: Source<T>) -> Option<T>
52    where
53        T: Send + 'static,
54    {
55        source.into_boxed().next().await
56    }
57
58    /// akka.net: `Sink.Last`.
59    pub async fn last<T>(source: Source<T>) -> Option<T>
60    where
61        T: Send + 'static,
62    {
63        source.into_boxed().fold(None, |_, x| async move { Some(x) }).await
64    }
65
66    /// akka.net: `Sink.Sum`.
67    pub async fn sum<T>(source: Source<T>) -> T
68    where
69        T: Send + Default + std::ops::Add<Output = T> + 'static,
70    {
71        let init: T = T::default();
72        Self::fold(source, init, |acc, x| acc + x).await
73    }
74
75    /// akka.net: `Sink.Count`.
76    pub async fn count<T>(source: Source<T>) -> u64
77    where
78        T: Send + 'static,
79    {
80        Self::fold(source, 0u64, |acc, _| acc + 1).await
81    }
82
83    pub async fn for_each<T, F>(source: Source<T>, mut f: F)
84    where
85        T: Send + 'static,
86        F: FnMut(T) + Send + 'static,
87    {
88        source
89            .into_boxed()
90            .for_each(move |x| {
91                f(x);
92                futures::future::ready(())
93            })
94            .await
95    }
96
97    /// akka.net: `Sink.ForEachAsync`.
98    pub async fn for_each_async<T, F, Fut>(source: Source<T>, parallelism: usize, f: F)
99    where
100        T: Send + 'static,
101        F: FnMut(T) -> Fut + Send + 'static,
102        Fut: Future<Output = ()> + Send + 'static,
103    {
104        let p = parallelism.max(1);
105        source.into_boxed().for_each_concurrent(p, f).await
106    }
107
108    /// akka.net: `Sink.Ignore`.
109    pub async fn ignore<T: Send + 'static>(source: Source<T>) {
110        source.into_boxed().for_each(|_| futures::future::ready(())).await
111    }
112
113    /// Send each element to an `UnboundedSender`. akka.net: `Sink.ActorRef`
114    /// (atomr equivalent uses an mpsc channel).
115    pub async fn to_sender<T>(source: Source<T>, tx: tokio::sync::mpsc::UnboundedSender<T>)
116    where
117        T: Send + 'static,
118    {
119        let mut stream = source.into_boxed();
120        while let Some(v) = stream.next().await {
121            if tx.send(v).is_err() {
122                break;
123            }
124        }
125    }
126
127    /// akka.net: `Sink.Queue` — run the source and expose a pull-based API.
128    /// The returned `SinkQueue::pull` future returns `Ok(Some(t))` per element,
129    /// `Ok(None)` after the stream completes.
130    pub fn queue<T>(source: Source<T>) -> SinkQueue<T>
131    where
132        T: Send + 'static,
133    {
134        let buf: Arc<Mutex<SinkQueueState<T>>> = Arc::new(Mutex::new(SinkQueueState::default()));
135        let notify = Arc::new(Notify::new());
136        let buf_t = Arc::clone(&buf);
137        let notify_t = Arc::clone(&notify);
138        let handle = tokio::spawn(async move {
139            let mut stream = source.into_boxed();
140            while let Some(v) = stream.next().await {
141                buf_t.lock().items.push_back(v);
142                notify_t.notify_one();
143            }
144            buf_t.lock().complete = true;
145            notify_t.notify_waiters();
146        });
147        SinkQueue { buf, notify, _handle: handle }
148    }
149
150    /// `Sink.Queue` with a bounded element timeout per pull.
151    pub async fn pull_with_timeout<T: Send + 'static>(q: &SinkQueue<T>, t: Duration) -> Option<T> {
152        tokio::time::timeout(t, q.pull()).await.ok().flatten()
153    }
154}
155
/// State shared between the producer task and the `SinkQueue` handle.
struct SinkQueueState<T> {
    /// Elements produced by the source that have not been pulled yet, oldest first.
    items: std::collections::VecDeque<T>,
    /// Set once the producer has finished; no further elements will be pushed.
    complete: bool,
}
160
161impl<T> Default for SinkQueueState<T> {
162    fn default() -> Self {
163        Self { items: std::collections::VecDeque::new(), complete: false }
164    }
165}
166
167pub struct SinkQueue<T> {
168    buf: Arc<Mutex<SinkQueueState<T>>>,
169    notify: Arc<Notify>,
170    _handle: tokio::task::JoinHandle<()>,
171}
172
173impl<T: Send + 'static> SinkQueue<T> {
174    /// Pull the next element, awaiting as long as the source is still running.
175    /// Returns `None` once the source completes.
176    pub async fn pull(&self) -> Option<T> {
177        loop {
178            {
179                let mut guard = self.buf.lock();
180                if let Some(v) = guard.items.pop_front() {
181                    return Some(v);
182                }
183                if guard.complete {
184                    return None;
185                }
186            }
187            self.notify.notified().await;
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// The single-value sinks each return the expected result for a small source.
    #[tokio::test]
    async fn first_last_sum_count() {
        let head = Sink::first(Source::from_iter(vec![1, 2, 3])).await;
        assert_eq!(head, Some(1));

        let tail = Sink::last(Source::from_iter(vec![1, 2, 3])).await;
        assert_eq!(tail, Some(3));

        let total = Sink::sum(Source::from_iter(1..=10_i32)).await;
        assert_eq!(total, 55);

        let n = Sink::count(Source::from_iter(0..42_u64)).await;
        assert_eq!(n, 42);
    }

    /// Every element is handed to the async callback exactly once.
    #[tokio::test]
    async fn for_each_async_runs_all_tasks() {
        let total = std::sync::Arc::new(std::sync::Mutex::new(0i32));
        let shared = std::sync::Arc::clone(&total);
        let add = move |v| {
            let cell = std::sync::Arc::clone(&shared);
            async move {
                let mut guard = cell.lock().unwrap();
                *guard += v;
            }
        };
        Sink::for_each_async(Source::from_iter(1..=5), 2, add).await;
        assert_eq!(*total.lock().unwrap(), 15);
    }

    /// Pulling yields the elements in order, then `None` once the source is done.
    #[tokio::test]
    async fn sink_queue_pulls_until_complete() {
        let q = Sink::queue(Source::from_iter(vec![10, 20, 30]));
        for expected in [Some(10), Some(20), Some(30), None] {
            assert_eq!(q.pull().await, expected);
        }
    }
}
226}