Skip to main content

atomr_streams/
sink.rs

//! Sink — consumes a `Source` and produces a materialized value.
//!
//! Most factories here are `async fn`s: the returned future drives the source
//! to completion and resolves to the materialized value. These wrappers mirror
//! the most common Sinks (`Fold`, `Aggregate`, `Sum`, `First`, `Last`, `Seq`,
//! `ForEach`, `Ignore`). `Sink::queue` is the exception: it spawns a task that
//! drains the source and returns a lightweight, pull-based `SinkQueue`.
7
8use std::future::Future;
9use std::sync::Arc;
10use std::time::Duration;
11
12use futures::stream::StreamExt;
13use parking_lot::Mutex;
14use tokio::sync::Notify;
15
16use crate::source::Source;
17
/// Namespace for sink factories.
///
/// `Sink` carries no state; every operation is an associated function that
/// takes the `Source` to consume as its first argument.
pub struct Sink;
19
20impl Sink {
21    /// Drive the source and accumulate a single value.
22    pub async fn fold<T, Acc, F>(source: Source<T>, init: Acc, mut f: F) -> Acc
23    where
24        T: Send + 'static,
25        Acc: Send + 'static,
26        F: FnMut(Acc, T) -> Acc + Send + 'static,
27    {
28        source.into_boxed().fold(init, move |acc, x| futures::future::ready(f(acc, x))).await
29    }
30
31    /// Async fold.
32    pub async fn fold_async<T, Acc, F, Fut>(source: Source<T>, init: Acc, f: F) -> Acc
33    where
34        T: Send + 'static,
35        Acc: Send + 'static,
36        F: FnMut(Acc, T) -> Fut + Send + 'static,
37        Fut: Future<Output = Acc> + Send + 'static,
38    {
39        source.into_boxed().fold(init, f).await
40    }
41
42    /// Collect into a Vec.
43    pub async fn collect<T>(source: Source<T>) -> Vec<T>
44    where
45        T: Send + 'static,
46    {
47        source.into_boxed().collect().await
48    }
49
50    pub async fn first<T>(source: Source<T>) -> Option<T>
51    where
52        T: Send + 'static,
53    {
54        source.into_boxed().next().await
55    }
56
57    pub async fn last<T>(source: Source<T>) -> Option<T>
58    where
59        T: Send + 'static,
60    {
61        source.into_boxed().fold(None, |_, x| async move { Some(x) }).await
62    }
63
64    pub async fn sum<T>(source: Source<T>) -> T
65    where
66        T: Send + Default + std::ops::Add<Output = T> + 'static,
67    {
68        let init: T = T::default();
69        Self::fold(source, init, |acc, x| acc + x).await
70    }
71
72    pub async fn count<T>(source: Source<T>) -> u64
73    where
74        T: Send + 'static,
75    {
76        Self::fold(source, 0u64, |acc, _| acc + 1).await
77    }
78
79    pub async fn for_each<T, F>(source: Source<T>, mut f: F)
80    where
81        T: Send + 'static,
82        F: FnMut(T) + Send + 'static,
83    {
84        source
85            .into_boxed()
86            .for_each(move |x| {
87                f(x);
88                futures::future::ready(())
89            })
90            .await
91    }
92
93    pub async fn for_each_async<T, F, Fut>(source: Source<T>, parallelism: usize, f: F)
94    where
95        T: Send + 'static,
96        F: FnMut(T) -> Fut + Send + 'static,
97        Fut: Future<Output = ()> + Send + 'static,
98    {
99        let p = parallelism.max(1);
100        source.into_boxed().for_each_concurrent(p, f).await
101    }
102
103    pub async fn ignore<T: Send + 'static>(source: Source<T>) {
104        source.into_boxed().for_each(|_| futures::future::ready(())).await
105    }
106
107    /// Send each element to an `UnboundedSender`.
108    /// (atomr equivalent uses an mpsc channel).
109    pub async fn to_sender<T>(source: Source<T>, tx: tokio::sync::mpsc::UnboundedSender<T>)
110    where
111        T: Send + 'static,
112    {
113        let mut stream = source.into_boxed();
114        while let Some(v) = stream.next().await {
115            if tx.send(v).is_err() {
116                break;
117            }
118        }
119    }
120
121    /// Run the source and expose a pull-based API.
122    /// The returned `SinkQueue::pull` future returns `Ok(Some(t))` per element,
123    /// `Ok(None)` after the stream completes.
124    pub fn queue<T>(source: Source<T>) -> SinkQueue<T>
125    where
126        T: Send + 'static,
127    {
128        let buf: Arc<Mutex<SinkQueueState<T>>> = Arc::new(Mutex::new(SinkQueueState::default()));
129        let notify = Arc::new(Notify::new());
130        let buf_t = Arc::clone(&buf);
131        let notify_t = Arc::clone(&notify);
132        let handle = tokio::spawn(async move {
133            let mut stream = source.into_boxed();
134            while let Some(v) = stream.next().await {
135                buf_t.lock().items.push_back(v);
136                notify_t.notify_one();
137            }
138            buf_t.lock().complete = true;
139            notify_t.notify_waiters();
140        });
141        SinkQueue { buf, notify, _handle: handle }
142    }
143
144    /// `Sink.Queue` with a bounded element timeout per pull.
145    pub async fn pull_with_timeout<T: Send + 'static>(q: &SinkQueue<T>, t: Duration) -> Option<T> {
146        tokio::time::timeout(t, q.pull()).await.ok().flatten()
147    }
148}
149
/// State shared between a `SinkQueue` and the task that feeds it.
struct SinkQueueState<T> {
    /// Elements received from the source and not yet pulled, oldest first.
    items: std::collections::VecDeque<T>,
    /// Set once the producer has finished; no more elements will be pushed.
    complete: bool,
}
154
155impl<T> Default for SinkQueueState<T> {
156    fn default() -> Self {
157        Self { items: std::collections::VecDeque::new(), complete: false }
158    }
159}
160
161pub struct SinkQueue<T> {
162    buf: Arc<Mutex<SinkQueueState<T>>>,
163    notify: Arc<Notify>,
164    _handle: tokio::task::JoinHandle<()>,
165}
166
167impl<T: Send + 'static> SinkQueue<T> {
168    /// Pull the next element, awaiting as long as the source is still running.
169    /// Returns `None` once the source completes.
170    pub async fn pull(&self) -> Option<T> {
171        loop {
172            {
173                let mut guard = self.buf.lock();
174                if let Some(v) = guard.items.pop_front() {
175                    return Some(v);
176                }
177                if guard.complete {
178                    return None;
179                }
180            }
181            self.notify.notified().await;
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// The scalar sinks each resolve to the expected value on small inputs.
    #[tokio::test]
    async fn first_last_sum_count() {
        let head = Sink::first(Source::from_iter(vec![1, 2, 3])).await;
        assert_eq!(head, Some(1));

        let tail = Sink::last(Source::from_iter(vec![1, 2, 3])).await;
        assert_eq!(tail, Some(3));

        let total = Sink::sum(Source::from_iter(1..=10_i32)).await;
        assert_eq!(total, 55);

        let seen = Sink::count(Source::from_iter(0..42_u64)).await;
        assert_eq!(seen, 42);
    }

    /// Every element reaches the async callback when running two at a time.
    #[tokio::test]
    async fn for_each_async_runs_all_tasks() {
        let total = std::sync::Arc::new(std::sync::Mutex::new(0i32));
        let shared = std::sync::Arc::clone(&total);
        let add_to_total = move |v| {
            let slot = std::sync::Arc::clone(&shared);
            async move {
                let mut guard = slot.lock().unwrap();
                *guard += v;
            }
        };

        Sink::for_each_async(Source::from_iter(1..=5), 2, add_to_total).await;

        assert_eq!(*total.lock().unwrap(), 15);
    }

    /// The queue yields each element in order, then `None` after completion.
    #[tokio::test]
    async fn sink_queue_pulls_until_complete() {
        let q = Sink::queue(Source::from_iter(vec![10, 20, 30]));
        for expected in [Some(10), Some(20), Some(30), None] {
            assert_eq!(q.pull().await, expected);
        }
    }
}
220}